diff --git a/.gitlab/issue_templates/release.md b/.gitlab/issue_templates/release.md index a95cf033eed919339c6c1734638542c3e0cdbc57..c8289cdd1be4ab02a61c6feb0db9db7cb6ca40d3 100644 --- a/.gitlab/issue_templates/release.md +++ b/.gitlab/issue_templates/release.md @@ -14,6 +14,7 @@ vX.Y.Z * [ ] Adjust `changelog.md` (see template for changelog) * [ ] Update version number in `mlair/__init__.py` * [ ] Create new dist file: `python3 setup.py sdist bdist_wheel` +* [ ] Add new dist file `mlair-X.Y.Z-py3-none-any.whl` to git * [ ] Update file link `distribution file (current version)` in `README.md` * [ ] Update file link in `docs/_source/installation.rst` * [ ] Commit + push diff --git a/CHANGELOG.md b/CHANGELOG.md index b11a169c854465c5ea932f00f5da5a1688df7c18..34795b8333df846d5383fc2d8eca4b40517aab73 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,7 +1,25 @@ # Changelog All notable changes to this project will be documented in this file. -## v1.4.0 - 2021-07-27 - <release description> +## v1.5.0 - 2021-11-11 - new uncertainty estimation + +### general: +* introduces a method to estimate sample uncertainty +* improved multiprocessing +* last release with tensorflow v1 support + +### new features: +* test set sample uncertainty estimation during postprocessing (#333) +* support of Kolmogorov-Zurbenko filter for data handlers with filters (#334) + +### technical: +* new communication scheme for multiprocessing (#321, #322) +* improved error reporting (#323) +* feature importance now returns unaggregated results (#335) +* error metrics are reported for all competitors (#332) +* minor bugfixes and refactorings (#330, #326, #329, #325, #324, #320, #337) + +## v1.4.0 - 2021-07-27 - new model classes and data handlers, improved usability and transparency ### general: * many technical adjustments to improve usability and transparency of MLAir diff --git a/HPC_setup/create_runscripts_HPC.sh b/HPC_setup/create_runscripts_HPC.sh index 5e37d820ae1241c09c1c87c141bdcf005044a3b7..730aa52ef42144826bd000d88c0fc81c9d508de0 100755 --- a/HPC_setup/create_runscripts_HPC.sh +++ b/HPC_setup/create_runscripts_HPC.sh @@ -85,7 +85,7 @@ source venv_${hpcsys}/bin/activate timestamp=\`date +"%Y-%m-%d_%H%M-%S"\` -export PYTHONPATH=\${PWD}/venv_${hpcsys}/lib/python3.6/site-packages:\${PYTHONPATH} +export PYTHONPATH=\${PWD}/venv_${hpcsys}/lib/python3.8/site-packages:\${PYTHONPATH} srun --cpu-bind=none python run.py --experiment_date=\$timestamp EOT @@ -102,6 +102,7 @@ cat <<EOT > ${cur}/run_${hpcsys}_batch.bash #SBATCH --output=${hpclogging}mlt-out.%j #SBATCH --error=${hpclogging}mlt-err.%j #SBATCH --time=08:00:00 +#SBATCH --gres=gpu:4 #SBATCH --mail-type=ALL #SBATCH --mail-user=${email} @@ -110,7 +111,7 @@ source venv_${hpcsys}/bin/activate timestamp=\`date +"%Y-%m-%d_%H%M-%S"\` -export PYTHONPATH=\${PWD}/venv_${hpcsys}/lib/python3.6/site-packages:\${PYTHONPATH} +export PYTHONPATH=\${PWD}/venv_${hpcsys}/lib/python3.8/site-packages:\${PYTHONPATH} srun --cpu-bind=none python run_HPC.py --experiment_date=\$timestamp EOT diff --git a/HPC_setup/mlt_modules_juwels.sh b/HPC_setup/mlt_modules_juwels.sh index e72c0f63141bad4bab442e18b93d9fbb37adb287..ffacfe6fc45302dfa60b108ca2493d9a27408df1 100755 --- a/HPC_setup/mlt_modules_juwels.sh +++ b/HPC_setup/mlt_modules_juwels.sh @@ -11,7 +11,7 @@ module use $OTHERSTAGES ml Stages/2020 ml GCCcore/.10.3.0 -# ml Jupyter/2021.3.1-Python-3.8.5 +ml Jupyter/2021.3.1-Python-3.8.5 ml Python/3.8.5 ml TensorFlow/2.5.0-Python-3.8.5 ml SciPy-Stack/2021-Python-3.8.5 diff --git a/HPC_setup/setup_venv_hdfml.sh
b/HPC_setup/setup_venv_hdfml.sh index 7e8334dd26874514c4fcfa686c49eeb7e1cabf0d..dbe001d587f3c0333a7e06cbe41c93759a53efda 100644 --- a/HPC_setup/setup_venv_hdfml.sh +++ b/HPC_setup/setup_venv_hdfml.sh @@ -31,14 +31,14 @@ echo "##### FINISH INSTALLING requirements_HDFML_additionals.txt #####" # pip install --ignore-installed matplotlib==3.2.0 # pip install --ignore-installed pandas==1.0.1 # pip install --ignore-installed statsmodels==0.11.1 -pip install --ignore-installed tabulate -pip install -U typing_extensions +# pip install --ignore-installed tabulate +# pip install -U typing_extensions # see wiki on hdfml for information on h5py: # https://gitlab.version.fz-juelich.de/haf/Wiki/-/wikis/HDF-ML%20System export CC=mpicc export HDF5_MPI="ON" pip install --no-binary=h5py h5py -pip install --ignore-installed netcdf4==1.5.4 +# pip install --ignore-installed netcdf4==1.5.4 python -m pip install "dask[complete]" diff --git a/HPC_setup/setup_venv_juwels.sh b/HPC_setup/setup_venv_juwels.sh index ba44900ee2db3e3cde63b4d38c05e643eb154d5c..242ff376b92bac0b77b3224c9703d67c2a6f4136 100755 --- a/HPC_setup/setup_venv_juwels.sh +++ b/HPC_setup/setup_venv_juwels.sh @@ -31,7 +31,7 @@ echo "##### FINISH INSTALLING requirements_JUWELS_additionals.txt #####" # pip install --ignore-installed matplotlib==3.2.0 # pip install --ignore-installed pandas==1.0.1 -pip install -U typing_extensions +# pip install -U typing_extensions python -m pip install --ignore-installed "dask[complete]==2021.3.0" # Comment: Maybe we have to export PYTHONPATH a second time after activating the venv (after job allocation) diff --git a/README.md b/README.md index 0e1df0561d15b743a85b0981b552a1444b6cc38c..a5fce2e53d82e3cff75a4f61000c616c62cbec69 100644 --- a/README.md +++ b/README.md @@ -25,13 +25,15 @@ HPC systems, see [here](#special-instructions-for-installation-on-jülich-hpc-sy * Install all **requirements** from [`requirements.txt`](https://gitlab.version.fz-juelich.de/toar/mlair/-/blob/master/requirements.txt) preferably in a virtual environment. You can use `pip install -r requirements.txt` to install all requirements at once. Note, we recently updated the version of Cartopy and there seems to be an ongoing - [issue](https://github.com/SciTools/cartopy/issues/1552) when installing numpy and Cartopy at the same time. If you - run into trouble, you could use `cat requirements.txt | cut -f1 -d"#" | sed '/^\s*$/d' | xargs -L 1 pip install` - instead. + [issue](https://github.com/SciTools/cartopy/issues/1552) when installing **numpy** and **Cartopy** at the same time. + If you run into trouble, you could use + `cat requirements.txt | cut -f1 -d"#" | sed '/^\s*$/d' | xargs -L 1 pip install` instead, or first install numpy with + `pip install numpy==<version_from_reqs>` followed by the default installation of the requirements. For the latter, you can + also use `grep numpy requirements.txt | xargs pip install`. * Installation of **MLAir**: * Either clone MLAir from the [gitlab repository](https://gitlab.version.fz-juelich.de/toar/mlair.git) and use it without installation (besides the requirements) - * or download the distribution file ([current version](https://gitlab.version.fz-juelich.de/toar/mlair/-/blob/master/dist/mlair-1.4.0-py3-none-any.whl)) + * or download the distribution file ([current version](https://gitlab.version.fz-juelich.de/toar/mlair/-/blob/master/dist/mlair-1.5.0-py3-none-any.whl)) and install it via `pip install <dist_file>.whl`.
In this case, you can simply import MLAir in any python script inside your virtual environment using `import mlair`. * (tf) Currently, TensorFlow-1.13 is mentioned in the requirements. We already tested the TensorFlow-1.15 version and couldn't diff --git a/dist/mlair-1.5.0-py3-none-any.whl b/dist/mlair-1.5.0-py3-none-any.whl new file mode 100644 index 0000000000000000000000000000000000000000..34495d960b009737fb40bec6dfe3a96effd14c02 Binary files /dev/null and b/dist/mlair-1.5.0-py3-none-any.whl differ diff --git a/docs/_source/installation.rst b/docs/_source/installation.rst index 27543ac109439e487756cc211ecc47be946c586c..c87e64b217b4207185cfc662fdf00d2f7e891cc5 100644 --- a/docs/_source/installation.rst +++ b/docs/_source/installation.rst @@ -26,7 +26,7 @@ Installation of MLAir * Install all requirements from `requirements.txt <https://gitlab.version.fz-juelich.de/toar/machinelearningtools/-/blob/master/requirements.txt>`_ preferably in a virtual environment * Either clone MLAir from the `gitlab repository <https://gitlab.version.fz-juelich.de/toar/machinelearningtools.git>`_ -* or download the distribution file (`current version <https://gitlab.version.fz-juelich.de/toar/mlair/-/blob/master/dist/mlair-1.4.0-py3-none-any.whl>`_) +* or download the distribution file (`current version <https://gitlab.version.fz-juelich.de/toar/mlair/-/blob/master/dist/mlair-1.5.0-py3-none-any.whl>`_) and install it via :py:`pip install <dist_file>.whl`. In this case, you can simply import MLAir in any python script inside your virtual environment using :py:`import mlair`. * (tf) Currently, TensorFlow-1.13 is mentioned in the requirements. We already tested the TensorFlow-1.15 version and couldn't diff --git a/docs/requirements_docs.txt b/docs/requirements_docs.txt index a39acca8a7cf887237f595d7992960ac10233a85..ee455d83f0debc10faa09ffd82cad9a77930d936 100644 --- a/docs/requirements_docs.txt +++ b/docs/requirements_docs.txt @@ -1,6 +1,9 @@ sphinx==3.0.3 -sphinx-autoapi==1.3.0 -sphinx-autodoc-typehints==1.10.3 +sphinx-autoapi==1.8.4 +sphinx-autodoc-typehints==1.12.0 sphinx-rtd-theme==0.4.3 #recommonmark==0.6.0 -m2r2==0.2.5 \ No newline at end of file +m2r2==0.3.1 +docutils<0.18 +mistune==0.8.4 +setuptools>=59.5.0 \ No newline at end of file diff --git a/mlair/__init__.py b/mlair/__init__.py index f760f9b0fa4b87bde1f6ee409626f4428083d895..75359e1773edea55ecc47556a83a465510fac6c8 100644 --- a/mlair/__init__.py +++ b/mlair/__init__.py @@ -1,6 +1,6 @@ __version_info__ = { 'major': 1, - 'minor': 4, + 'minor': 5, 'micro': 0, } diff --git a/mlair/configuration/defaults.py b/mlair/configuration/defaults.py index 4d762581ecc64ba4f4335d77a48bd1c1d913e01c..0e529253ef7787b7c716b6813b3054a29821b989 100644 --- a/mlair/configuration/defaults.py +++ b/mlair/configuration/defaults.py @@ -55,7 +55,7 @@ DEFAULT_FEATURE_IMPORTANCE_N_BOOTS = 20 DEFAULT_FEATURE_IMPORTANCE_BOOTSTRAP_TYPE = "singleinput" DEFAULT_FEATURE_IMPORTANCE_BOOTSTRAP_METHOD = "shuffle" DEFAULT_PLOT_LIST = ["PlotMonthlySummary", "PlotStationMap", "PlotClimatologicalSkillScore", "PlotTimeSeries", - "PlotCompetitiveSkillScore", "PlotBootstrapSkillScore", "PlotConditionalQuantiles", + "PlotCompetitiveSkillScore", "PlotFeatureImportanceSkillScore", "PlotConditionalQuantiles", "PlotAvailability", "PlotAvailabilityHistogram", "PlotDataHistogram", "PlotPeriodogram", "PlotSampleUncertaintyFromBootstrap"] DEFAULT_SAMPLING = "daily" diff --git a/mlair/data_handler/abstract_data_handler.py b/mlair/data_handler/abstract_data_handler.py index 
a6f49d2d756200a8a4e5f15c13ecf385ef08dabc..aafdb80c3455d2f659e5a952d81c1a0b841eea2e 100644 --- a/mlair/data_handler/abstract_data_handler.py +++ b/mlair/data_handler/abstract_data_handler.py @@ -6,16 +6,17 @@ import inspect import logging from typing import Union, Dict + import dask from dask.diagnostics import ProgressBar - -from mlair.helpers import remove_items +from mlair.helpers import remove_items, to_list -class AbstractDataHandler: +class AbstractDataHandler(object): _requirements = [] _store_attributes = [] + _skip_args = ["self"] def __init__(self, *args, **kwargs): pass @@ -26,16 +27,28 @@ class AbstractDataHandler: return cls(*args, **kwargs) @classmethod - def requirements(cls): + def requirements(cls, skip_args=None): """Return requirements and own arguments without duplicates.""" - return list(set(cls._requirements + cls.own_args())) + skip_args = cls._skip_args if skip_args is None else cls._skip_args + to_list(skip_args) + return remove_items(list(set(cls._requirements + cls.own_args())), skip_args) @classmethod def own_args(cls, *args): """Return all arguments (including kwonlyargs).""" arg_spec = inspect.getfullargspec(cls) - list_of_args = arg_spec.args + arg_spec.kwonlyargs - return remove_items(list_of_args, ["self"] + list(args)) + list_of_args = arg_spec.args + arg_spec.kwonlyargs + cls.super_args() + return list(set(remove_items(list_of_args, list(args)))) + + @classmethod + def super_args(cls): + args = [] + for super_cls in cls.__mro__: + if super_cls == cls: + continue + if hasattr(super_cls, "own_args"): + # args.extend(super_cls.own_args()) + args.extend(getattr(super_cls, "own_args")()) + return list(set(args)) @classmethod def store_attributes(cls) -> list: @@ -86,6 +99,9 @@ class AbstractDataHandler: """Return coordinates as dictionary with keys `lon` and `lat`.""" return None + def get_wind_upstream_sector_by_name(self): + raise NotImplementedError + def _hash_list(self): return [] diff --git a/mlair/data_handler/data_handler_mixed_sampling.py b/mlair/data_handler/data_handler_mixed_sampling.py index 5aefb0368ec1cf544443bb5e0412dd16a97f2a7f..3de749d02375243269f9eb51c08400840fd0656a 100644 --- a/mlair/data_handler/data_handler_mixed_sampling.py +++ b/mlair/data_handler/data_handler_mixed_sampling.py @@ -2,30 +2,25 @@ __author__ = 'Lukas Leufen' __date__ = '2020-11-05' from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation -from mlair.data_handler.data_handler_with_filter import DataHandlerKzFilterSingleStation, \ - DataHandlerFirFilterSingleStation, DataHandlerFilterSingleStation, DataHandlerClimateFirFilterSingleStation -from mlair.data_handler.data_handler_with_filter import DataHandlerClimateFirFilter, DataHandlerFirFilter, \ - DataHandlerKzFilter +from mlair.data_handler.data_handler_with_filter import DataHandlerFirFilterSingleStation, \ + DataHandlerFilterSingleStation, DataHandlerClimateFirFilterSingleStation +from mlair.data_handler.data_handler_with_filter import DataHandlerClimateFirFilter, DataHandlerFirFilter from mlair.data_handler import DefaultDataHandler from mlair import helpers -from mlair.helpers import remove_items +from mlair.helpers import to_list from mlair.configuration.defaults import DEFAULT_SAMPLING, DEFAULT_INTERPOLATION_LIMIT, DEFAULT_INTERPOLATION_METHOD from mlair.helpers.filter import filter_width_kzf import copy -import inspect -from typing import Callable import datetime as dt from typing import Any from functools import partial -import numpy as np import pandas as pd import xarray as xr class 
DataHandlerMixedSamplingSingleStation(DataHandlerSingleStation): - _requirements = remove_items(inspect.getfullargspec(DataHandlerSingleStation).args, ["self", "station"]) def __init__(self, *args, **kwargs): """ @@ -101,9 +96,6 @@ class DataHandlerMixedSampling(DefaultDataHandler): class DataHandlerMixedSamplingWithFilterSingleStation(DataHandlerMixedSamplingSingleStation, DataHandlerFilterSingleStation): - _requirements1 = DataHandlerFilterSingleStation.requirements() - _requirements2 = DataHandlerMixedSamplingSingleStation.requirements() - _requirements = list(set(_requirements1 + _requirements2)) def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -111,6 +103,16 @@ class DataHandlerMixedSamplingWithFilterSingleStation(DataHandlerMixedSamplingSi def _check_sampling(self, **kwargs): assert kwargs.get("sampling") == ("hourly", "daily") + def apply_filter(self): + raise NotImplementedError + + def create_filter_index(self) -> pd.Index: + """Create name for filter dimension.""" + raise NotImplementedError + + def _create_lazy_data(self): + raise NotImplementedError + def make_input_target(self): """ A FIR filter is applied on the input data that has hourly resolution. Lables Y are provided as aggregated values @@ -159,46 +161,31 @@ class DataHandlerMixedSamplingWithFilterSingleStation(DataHandlerMixedSamplingSi self.target_data = self._slice_prep(_target_data, self.start, self.end) -class DataHandlerMixedSamplingWithKzFilterSingleStation(DataHandlerMixedSamplingWithFilterSingleStation, - DataHandlerKzFilterSingleStation): - _requirements1 = DataHandlerKzFilterSingleStation.requirements() - _requirements2 = DataHandlerMixedSamplingWithFilterSingleStation.requirements() - _requirements = list(set(_requirements1 + _requirements2)) - - def estimate_filter_width(self): - """ - f = 0.5 / (len * sqrt(itr)) -> T = 1 / f - :return: - """ - return int(self.kz_filter_length[0] * np.sqrt(self.kz_filter_iter[0]) * 2) - - def _extract_lazy(self, lazy_data): - _data, _meta, _input_data, _target_data, self.cutoff_period, self.cutoff_period_days, \ - self.filter_dim_order = lazy_data - super(__class__, self)._extract_lazy((_data, _meta, _input_data, _target_data)) - - -class DataHandlerMixedSamplingWithKzFilter(DataHandlerKzFilter): - """Data handler using mixed sampling for input and target. 
Inputs are temporal filtered.""" - - data_handler = DataHandlerMixedSamplingWithKzFilterSingleStation - data_handler_transformation = DataHandlerMixedSamplingWithKzFilterSingleStation - _requirements = data_handler.requirements() - - class DataHandlerMixedSamplingWithFirFilterSingleStation(DataHandlerMixedSamplingWithFilterSingleStation, DataHandlerFirFilterSingleStation): - _requirements1 = DataHandlerFirFilterSingleStation.requirements() - _requirements2 = DataHandlerMixedSamplingWithFilterSingleStation.requirements() - _requirements = list(set(_requirements1 + _requirements2)) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) def estimate_filter_width(self): """Filter width is determined by the filter with the highest order.""" - return max(self.filter_order) + if isinstance(self.filter_order[0], tuple): + return max([filter_width_kzf(*e) for e in self.filter_order]) + else: + return max(self.filter_order) + + def apply_filter(self): + DataHandlerFirFilterSingleStation.apply_filter(self) + + def create_filter_index(self, add_unfiltered_index=True) -> pd.Index: + return DataHandlerFirFilterSingleStation.create_filter_index(self, add_unfiltered_index=add_unfiltered_index) def _extract_lazy(self, lazy_data): _data, _meta, _input_data, _target_data, self.fir_coeff, self.filter_dim_order = lazy_data - super(__class__, self)._extract_lazy((_data, _meta, _input_data, _target_data)) + DataHandlerMixedSamplingWithFilterSingleStation._extract_lazy(self, (_data, _meta, _input_data, _target_data)) + + def _create_lazy_data(self): + return DataHandlerFirFilterSingleStation._create_lazy_data(self) @staticmethod def _get_fs(**kwargs): @@ -220,18 +207,8 @@ class DataHandlerMixedSamplingWithFirFilter(DataHandlerFirFilter): _requirements = data_handler.requirements() -class DataHandlerMixedSamplingWithClimateFirFilterSingleStation(DataHandlerMixedSamplingWithFilterSingleStation, - DataHandlerClimateFirFilterSingleStation): - _requirements1 = DataHandlerClimateFirFilterSingleStation.requirements() - _requirements2 = DataHandlerMixedSamplingWithFilterSingleStation.requirements() - _requirements = list(set(_requirements1 + _requirements2)) - - def estimate_filter_width(self): - """Filter width is determined by the filter with the highest order.""" - if isinstance(self.filter_order[0], tuple): - return max([filter_width_kzf(*e) for e in self.filter_order]) - else: - return max(self.filter_order) +class DataHandlerMixedSamplingWithClimateFirFilterSingleStation(DataHandlerClimateFirFilterSingleStation, + DataHandlerMixedSamplingWithFirFilterSingleStation): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -241,17 +218,6 @@ class DataHandlerMixedSamplingWithClimateFirFilterSingleStation(DataHandlerMixed self.filter_dim_order = lazy_data DataHandlerMixedSamplingWithFilterSingleStation._extract_lazy(self, (_data, _meta, _input_data, _target_data)) - @staticmethod - def _get_fs(**kwargs): - """Return frequency in 1/day (not Hz)""" - sampling = kwargs.get("sampling")[0] - if sampling == "daily": - return 1 - elif sampling == "hourly": - return 24 - else: - raise ValueError(f"Unknown sampling rate {sampling}. Only daily and hourly resolution is supported.") - class DataHandlerMixedSamplingWithClimateFirFilter(DataHandlerClimateFirFilter): """Data handler using mixed sampling for input and target. 
Inputs are temporal filtered.""" @@ -268,19 +234,11 @@ class DataHandlerMixedSamplingWithClimateFirFilter(DataHandlerClimateFirFilter): self.filter_add_unfiltered = filter_add_unfiltered super().__init__(*args, **kwargs) - @classmethod - def own_args(cls, *args): - """Return all arguments (including kwonlyargs).""" - super_own_args = DataHandlerClimateFirFilter.own_args(*args) - arg_spec = inspect.getfullargspec(cls) - list_of_args = arg_spec.args + arg_spec.kwonlyargs + super_own_args - return remove_items(list_of_args, ["self"] + list(args)) - def _create_collection(self): + collection = super()._create_collection() if self.filter_add_unfiltered is True and self.dh_unfiltered is not None: - return [self.id_class, self.dh_unfiltered] - else: - return super()._create_collection() + collection.append(self.dh_unfiltered) + return collection @classmethod def build(cls, station: str, **kwargs): @@ -306,19 +264,23 @@ class DataHandlerMixedSamplingWithClimateFirFilter(DataHandlerClimateFirFilter): return kwargs_dict @classmethod - def transformation(cls, set_stations, tmp_path=None, **kwargs): + def transformation(cls, set_stations, tmp_path=None, dh_transformation=None, **kwargs): - sp_keys = {k: copy.deepcopy(kwargs[k]) for k in cls._requirements if k in kwargs} - if "transformation" not in sp_keys.keys(): + # sp_keys = {k: copy.deepcopy(kwargs[k]) for k in cls._requirements if k in kwargs} + if "transformation" not in kwargs.keys(): return + if dh_transformation is None: + dh_transformation = (cls.data_handler_transformation, cls.data_handler_unfiltered) + elif not isinstance(dh_transformation, tuple): + dh_transformation = (dh_transformation, dh_transformation) transformation_filtered = super().transformation(set_stations, tmp_path=tmp_path, - dh_transformation=cls.data_handler_transformation, **kwargs) + dh_transformation=dh_transformation[0], **kwargs) if kwargs.get("filter_add_unfiltered", False) is False: return transformation_filtered else: transformation_unfiltered = super().transformation(set_stations, tmp_path=tmp_path, - dh_transformation=cls.data_handler_unfiltered, **kwargs) + dh_transformation=dh_transformation[1], **kwargs) return {"filtered": transformation_filtered, "unfiltered": transformation_unfiltered} def get_X_original(self): @@ -337,80 +299,228 @@ class DataHandlerMixedSamplingWithClimateFirFilter(DataHandlerClimateFirFilter): return super().get_X_original() -class DataHandlerSeparationOfScalesSingleStation(DataHandlerMixedSamplingWithKzFilterSingleStation): - """ - Data handler using mixed sampling for input and target. Inputs are temporal filtered and depending on the - separation frequency of a filtered time series the time step delta for input data is adjusted (see image below). 
+class DataHandlerMixedSamplingWithClimateAndFirFilter(DataHandlerMixedSamplingWithClimateFirFilter): + # data_handler = DataHandlerMixedSamplingWithClimateFirFilterSingleStation + # data_handler_transformation = DataHandlerMixedSamplingWithClimateFirFilterSingleStation + # data_handler_unfiltered = DataHandlerMixedSamplingSingleStation + # _requirements = list(set(data_handler.requirements() + data_handler_unfiltered.requirements())) + # DEFAULT_FILTER_ADD_UNFILTERED = False + data_handler_climate_fir = DataHandlerMixedSamplingWithClimateFirFilterSingleStation + data_handler_fir = DataHandlerMixedSamplingWithFirFilterSingleStation + data_handler = None + data_handler_unfiltered = DataHandlerMixedSamplingSingleStation + _requirements = list(set(data_handler_climate_fir.requirements() + data_handler_fir.requirements() + + data_handler_unfiltered.requirements())) - .. image:: ../../../../../_source/_plots/separation_of_scales.png - :width: 400 + def __init__(self, data_handler_class_chem, data_handler_class_meteo, data_handler_class_chem_unfiltered, + data_handler_class_meteo_unfiltered, chem_vars, meteo_vars, *args, **kwargs): - """ + if len(chem_vars) > 0: + id_class, id_class_unfiltered = data_handler_class_chem, data_handler_class_chem_unfiltered + self.id_class_other = data_handler_class_meteo + self.id_class_other_unfiltered = data_handler_class_meteo_unfiltered + else: + id_class, id_class_unfiltered = data_handler_class_meteo, data_handler_class_meteo_unfiltered + self.id_class_other = data_handler_class_chem + self.id_class_other_unfiltered = data_handler_class_chem_unfiltered + super().__init__(id_class, *args, data_handler_class_unfiltered=id_class_unfiltered, **kwargs) - _requirements = DataHandlerMixedSamplingWithKzFilterSingleStation.requirements() - _hash = DataHandlerMixedSamplingWithKzFilterSingleStation._hash + ["time_delta"] + @classmethod + def _split_chem_and_meteo_variables(cls, **kwargs): + if "variables" in kwargs: + variables = kwargs.get("variables") + elif "statistics_per_var" in kwargs: + variables = kwargs.get("statistics_per_var") + else: + variables = None + if variables is None: + variables = cls.data_handler_climate_fir.DEFAULT_VAR_ALL_DICT.keys() + chem_vars = cls.data_handler_climate_fir.chem_vars + chem = set(variables).intersection(chem_vars) + meteo = set(variables).difference(chem_vars) + return to_list(chem), to_list(meteo) - def __init__(self, *args, time_delta=np.sqrt, **kwargs): - assert isinstance(time_delta, Callable) - self.time_delta = time_delta - super().__init__(*args, **kwargs) + @classmethod + def build(cls, station: str, **kwargs): + chem_vars, meteo_vars = cls._split_chem_and_meteo_variables(**kwargs) + filter_add_unfiltered = kwargs.get("filter_add_unfiltered", False) + sp_chem, sp_chem_unfiltered = None, None + sp_meteo, sp_meteo_unfiltered = None, None + + if len(chem_vars) > 0: + sp_keys = {k: copy.deepcopy(kwargs[k]) for k in cls.data_handler_climate_fir.requirements() if k in kwargs} + sp_keys = cls.build_update_kwargs(sp_keys, dh_type="filtered_chem") + sp_keys.update({"variables": chem_vars}) + cls.adjust_window_opts("chem", "window_history_size", sp_keys) + cls.adjust_window_opts("chem", "window_history_offset", sp_keys) + sp_chem = cls.data_handler_climate_fir(station, **sp_keys) + if filter_add_unfiltered is True: + sp_keys = {k: copy.deepcopy(kwargs[k]) for k in cls.data_handler_unfiltered.requirements() if k in kwargs} + sp_keys = cls.build_update_kwargs(sp_keys, dh_type="unfiltered_chem") + sp_keys.update({"variables": 
chem_vars}) + cls.adjust_window_opts("chem", "window_history_size", sp_keys) + cls.adjust_window_opts("chem", "window_history_offset", sp_keys) + sp_chem_unfiltered = cls.data_handler_unfiltered(station, **sp_keys) + if len(meteo_vars) > 0: + sp_keys = {k: copy.deepcopy(kwargs[k]) for k in cls.data_handler_fir.requirements() if k in kwargs} + sp_keys = cls.build_update_kwargs(sp_keys, dh_type="filtered_meteo") + sp_keys.update({"variables": meteo_vars}) + cls.adjust_window_opts("meteo", "window_history_size", sp_keys) + cls.adjust_window_opts("meteo", "window_history_offset", sp_keys) + sp_meteo = cls.data_handler_fir(station, **sp_keys) + if filter_add_unfiltered is True: + sp_keys = {k: copy.deepcopy(kwargs[k]) for k in cls.data_handler_unfiltered.requirements() if k in kwargs} + sp_keys = cls.build_update_kwargs(sp_keys, dh_type="unfiltered_meteo") + sp_keys.update({"variables": meteo_vars}) + cls.adjust_window_opts("meteo", "window_history_size", sp_keys) + cls.adjust_window_opts("meteo", "window_history_offset", sp_keys) + sp_meteo_unfiltered = cls.data_handler_unfiltered(station, **sp_keys) - def make_history_window(self, dim_name_of_inputs: str, window: int, dim_name_of_shift: str) -> None: - """ - Create a xr.DataArray containing history data. + dp_args = {k: copy.deepcopy(kwargs[k]) for k in cls.own_args("id_class") if k in kwargs} + return cls(sp_chem, sp_meteo, sp_chem_unfiltered, sp_meteo_unfiltered, chem_vars, meteo_vars, **dp_args) - Shift the data window+1 times and return a xarray which has a new dimension 'window' containing the shifted - data. This is used to represent history in the data. Results are stored in history attribute. + @staticmethod + def adjust_window_opts(key: str, parameter_name: str, kwargs: dict): + if parameter_name in kwargs: + window_opt = kwargs.pop(parameter_name) + if isinstance(window_opt, dict): + window_opt = window_opt[key] + kwargs[parameter_name] = window_opt - :param dim_name_of_inputs: Name of dimension which contains the input variables - :param window: number of time steps to look back in history - Note: window will be treated as negative value. This should be in agreement with looking back on - a time line. 
Nonetheless positive values are allowed but they are converted to its negative - expression - :param dim_name_of_shift: Dimension along shift will be applied - """ - window = -abs(window) - data = self.input_data - self.history = self.stride(data, dim_name_of_shift, window, offset=self.window_history_offset) - - def stride(self, data: xr.DataArray, dim: str, window: int, offset: int = 0) -> xr.DataArray: - - # this is just a code snippet to check the results of the kz filter - # import matplotlib - # matplotlib.use("TkAgg") - # import matplotlib.pyplot as plt - # xr.concat(res, dim="filter").sel({"variables":"temp", "Stations":"DEBW107", "datetime":"2010-01-01T00:00:00"}).plot.line(hue="filter") - - time_deltas = np.round(self.time_delta(self.cutoff_period)).astype(int) - start, end = window, 1 - res = [] - _range = list(map(lambda x: x + offset, range(start, end))) - window_array = self.create_index_array(self.window_dim, _range, squeeze_dim=self.target_dim) - for delta, filter_name in zip(np.append(time_deltas, 1), data.coords["filter"]): - res_filter = [] - data_filter = data.sel({"filter": filter_name}) - for w in _range: - res_filter.append(data_filter.shift({dim: -(w - offset) * delta - offset})) - res_filter = xr.concat(res_filter, dim=window_array).chunk() - res.append(res_filter) - res = xr.concat(res, dim="filter").compute() - return res + def _create_collection(self): + collection = super()._create_collection() + if self.id_class_other is not None: + collection.append(self.id_class_other) + if self.filter_add_unfiltered is True and self.id_class_other_unfiltered is not None: + collection.append(self.id_class_other_unfiltered) + return collection - def estimate_filter_width(self): - """ - Attention: this method returns the maximum value of - * either estimated filter width f = 0.5 / (len * sqrt(itr)) -> T = 1 / f or - * time delta method applied on the estimated filter width mupliplied by window_history_size - to provide a sufficiently wide filter width. - """ - est = self.kz_filter_length[0] * np.sqrt(self.kz_filter_iter[0]) * 2 - return int(max([self.time_delta(est) * self.window_history_size, est])) + @classmethod + def transformation(cls, set_stations, tmp_path=None, **kwargs): + if "transformation" not in kwargs.keys(): + return -class DataHandlerSeparationOfScales(DefaultDataHandler): - """Data handler using mixed sampling for input and target. 
Inputs are temporal filtered and different time step - sizes are applied in relation to frequencies.""" + chem_vars, meteo_vars = cls._split_chem_and_meteo_variables(**kwargs) + + # chem transformation + kwargs_chem = copy.deepcopy(kwargs) + kwargs_chem["variables"] = chem_vars + cls.adjust_window_opts("chem", "window_history_size", kwargs_chem) + cls.adjust_window_opts("chem", "window_history_offset", kwargs_chem) + dh_transformation = (cls.data_handler_climate_fir, cls.data_handler_unfiltered) + transformation_chem = super().transformation(set_stations, tmp_path=tmp_path, + dh_transformation=dh_transformation, **kwargs_chem) + + # meteo transformation + kwargs_meteo = copy.deepcopy(kwargs) + kwargs_meteo["variables"] = meteo_vars + cls.adjust_window_opts("meteo", "window_history_size", kwargs_meteo) + cls.adjust_window_opts("meteo", "window_history_offset", kwargs_meteo) + dh_transformation = (cls.data_handler_fir, cls.data_handler_unfiltered) + transformation_meteo = super().transformation(set_stations, tmp_path=tmp_path, + dh_transformation=dh_transformation, **kwargs_meteo) + + # combine all transformations + transformation_res = {} + if isinstance(transformation_chem, dict): + if len(transformation_chem) > 0: + transformation_res["filtered_chem"] = transformation_chem.pop("filtered") + transformation_res["unfiltered_chem"] = transformation_chem.pop("unfiltered") + else: # if no unfiltered chem branch + transformation_res["filtered_chem"] = transformation_chem + if isinstance(transformation_meteo, dict): + if len(transformation_meteo) > 0: + transformation_res["filtered_meteo"] = transformation_meteo.pop("filtered") + transformation_res["unfiltered_meteo"] = transformation_meteo.pop("unfiltered") + else: # if no unfiltered meteo branch + transformation_res["filtered_meteo"] = transformation_meteo + return transformation_res if len(transformation_res) > 0 else None - data_handler = DataHandlerSeparationOfScalesSingleStation - data_handler_transformation = DataHandlerSeparationOfScalesSingleStation - _requirements = data_handler.requirements() + def get_X_original(self): + if self.use_filter_branches is True: + X = [] + for data in self._collection: + if hasattr(data, "filter_dim"): + X_total = data.get_X() + filter_dim = data.filter_dim + for filter_name in data.filter_dim_order: + X.append(X_total.sel({filter_dim: filter_name}, drop=True)) + else: + X.append(data.get_X()) + return X + else: + return super().get_X_original() + + +# +# class DataHandlerSeparationOfScalesSingleStation(DataHandlerMixedSamplingWithKzFilterSingleStation): +# """ +# Data handler using mixed sampling for input and target. Inputs are temporal filtered and depending on the +# separation frequency of a filtered time series the time step delta for input data is adjusted (see image below). +# +# .. image:: ../../../../../_source/_plots/separation_of_scales.png +# :width: 400 +# +# """ +# +# _requirements = DataHandlerMixedSamplingWithKzFilterSingleStation.requirements() +# _hash = DataHandlerMixedSamplingWithKzFilterSingleStation._hash + ["time_delta"] +# +# def __init__(self, *args, time_delta=np.sqrt, **kwargs): +# assert isinstance(time_delta, Callable) +# self.time_delta = time_delta +# super().__init__(*args, **kwargs) +# +# def make_history_window(self, dim_name_of_inputs: str, window: int, dim_name_of_shift: str) -> None: +# """ +# Create a xr.DataArray containing history data. +# +# Shift the data window+1 times and return a xarray which has a new dimension 'window' containing the shifted +# data. 
This is used to represent history in the data. Results are stored in history attribute. +# +# :param dim_name_of_inputs: Name of dimension which contains the input variables +# :param window: number of time steps to look back in history +# Note: window will be treated as negative value. This should be in agreement with looking back on +# a time line. Nonetheless positive values are allowed but they are converted to its negative +# expression +# :param dim_name_of_shift: Dimension along shift will be applied +# """ +# window = -abs(window) +# data = self.input_data +# self.history = self.stride(data, dim_name_of_shift, window, offset=self.window_history_offset) +# +# def stride(self, data: xr.DataArray, dim: str, window: int, offset: int = 0) -> xr.DataArray: +# time_deltas = np.round(self.time_delta(self.cutoff_period)).astype(int) +# start, end = window, 1 +# res = [] +# _range = list(map(lambda x: x + offset, range(start, end))) +# window_array = self.create_index_array(self.window_dim, _range, squeeze_dim=self.target_dim) +# for delta, filter_name in zip(np.append(time_deltas, 1), data.coords["filter"]): +# res_filter = [] +# data_filter = data.sel({"filter": filter_name}) +# for w in _range: +# res_filter.append(data_filter.shift({dim: -(w - offset) * delta - offset})) +# res_filter = xr.concat(res_filter, dim=window_array).chunk() +# res.append(res_filter) +# res = xr.concat(res, dim="filter").compute() +# return res +# +# def estimate_filter_width(self): +# """ +# Attention: this method returns the maximum value of +# * either estimated filter width f = 0.5 / (len * sqrt(itr)) -> T = 1 / f or +# * time delta method applied on the estimated filter width mupliplied by window_history_size +# to provide a sufficiently wide filter width. +# """ +# est = self.kz_filter_length[0] * np.sqrt(self.kz_filter_iter[0]) * 2 +# return int(max([self.time_delta(est) * self.window_history_size, est])) +# +# +# class DataHandlerSeparationOfScales(DefaultDataHandler): +# """Data handler using mixed sampling for input and target. 
Inputs are temporal filtered and different time step +# sizes are applied in relation to frequencies.""" +# +# data_handler = DataHandlerSeparationOfScalesSingleStation +# data_handler_transformation = DataHandlerSeparationOfScalesSingleStation +# _requirements = data_handler.requirements() diff --git a/mlair/data_handler/data_handler_single_station.py b/mlair/data_handler/data_handler_single_station.py index 628e6b44aef975c67c468d1f1c27c6588bc701b6..429f9604593539e6060716c37b4c9736b6beed40 100644 --- a/mlair/data_handler/data_handler_single_station.py +++ b/mlair/data_handler/data_handler_single_station.py @@ -48,12 +48,14 @@ class DataHandlerSingleStation(AbstractDataHandler): DEFAULT_SAMPLING = "daily" DEFAULT_INTERPOLATION_LIMIT = 0 DEFAULT_INTERPOLATION_METHOD = "linear" + chem_vars = ["benzene", "ch4", "co", "ethane", "no", "no2", "nox", "o3", "ox", "pm1", "pm10", "pm2p5", "propane", + "so2", "toluene"] _hash = ["station", "statistics_per_var", "data_origin", "station_type", "network", "sampling", "target_dim", "target_var", "time_dim", "iter_dim", "window_dim", "window_history_size", "window_history_offset", "window_lead_time", "interpolation_limit", "interpolation_method"] - def __init__(self, station, data_path, statistics_per_var, station_type=DEFAULT_STATION_TYPE, + def __init__(self, station, data_path, statistics_per_var=None, station_type=DEFAULT_STATION_TYPE, network=DEFAULT_NETWORK, sampling: Union[str, Tuple[str]] = DEFAULT_SAMPLING, target_dim=DEFAULT_TARGET_DIM, target_var=DEFAULT_TARGET_VAR, time_dim=DEFAULT_TIME_DIM, iter_dim=DEFAULT_ITER_DIM, window_dim=DEFAULT_WINDOW_DIM, @@ -73,7 +75,7 @@ class DataHandlerSingleStation(AbstractDataHandler): if self.lazy is True: self.lazy_path = os.path.join(data_path, "lazy_data", self.__class__.__name__) check_path_and_create(self.lazy_path) - self.statistics_per_var = statistics_per_var + self.statistics_per_var = statistics_per_var or self.DEFAULT_VAR_ALL_DICT self.data_origin = data_origin self.do_transformation = transformation is not None self.input_data, self.target_data = None, None @@ -276,6 +278,7 @@ class DataHandlerSingleStation(AbstractDataHandler): filename = os.path.join(self.lazy_path, hash + ".pickle") try: if self.overwrite_lazy_data is True: + os.remove(filename) raise FileNotFoundError with open(filename, "rb") as pickle_file: lazy_data = dill.load(pickle_file) @@ -417,9 +420,7 @@ class DataHandlerSingleStation(AbstractDataHandler): :return: corrected data """ - chem_vars = ["benzene", "ch4", "co", "ethane", "no", "no2", "nox", "o3", "ox", "pm1", "pm10", "pm2p5", - "propane", "so2", "toluene"] - used_chem_vars = list(set(chem_vars) & set(data.coords[self.target_dim].values)) + used_chem_vars = list(set(self.chem_vars) & set(data.coords[self.target_dim].values)) if len(used_chem_vars) > 0: data.loc[..., used_chem_vars] = data.loc[..., used_chem_vars].clip(min=minimum) return data @@ -615,6 +616,12 @@ class DataHandlerSingleStation(AbstractDataHandler): non_nan_history = self.history.dropna(dim=dim) non_nan_label = self.label.dropna(dim=dim) non_nan_observation = self.observation.dropna(dim=dim) + if non_nan_label.coords[dim].shape[0] == 0: + raise ValueError(f'self.label consists of NaNs only - station {self.station} is therefore dropped') + if non_nan_history.coords[dim].shape[0] == 0: + raise ValueError(f'self.history consists of NaNs only - station {self.station} is therefore dropped') + if non_nan_observation.coords[dim].shape[0] == 0: + raise ValueError(f'self.observation consists of NaNs only - station 
{self.station} is therefore dropped') intersect = reduce(np.intersect1d, (non_nan_history.coords[dim].values, non_nan_label.coords[dim].values, non_nan_observation.coords[dim].values)) @@ -771,24 +778,3 @@ class DataHandlerSingleStation(AbstractDataHandler): def _get_hash(self): hash = "".join([str(self.__getattribute__(e)) for e in self._hash_list()]).encode() return hashlib.md5(hash).hexdigest() - - -if __name__ == "__main__": - statistics_per_var = {'o3': 'dma8eu', 'temp-rea-miub': 'maximum'} - sp = DataHandlerSingleStation(data_path='/home/felix/PycharmProjects/mlt_new/data/', station='DEBY122', - statistics_per_var=statistics_per_var, station_type='background', - network='UBA', sampling='daily', target_dim='variables', target_var='o3', - time_dim='datetime', window_history_size=7, window_lead_time=3, - interpolation_limit=0 - ) # transformation={'method': 'standardise'}) - sp2 = DataHandlerSingleStation(data_path='/home/felix/PycharmProjects/mlt_new/data/', station='DEBY122', - statistics_per_var=statistics_per_var, station_type='background', - network='UBA', sampling='daily', target_dim='variables', target_var='o3', - time_dim='datetime', window_history_size=7, window_lead_time=3, - transformation={'method': 'standardise'}) - sp2.transform(inverse=True) - sp.get_X() - sp.get_Y() - print(len(sp)) - print(sp.shape) - print(sp) diff --git a/mlair/data_handler/data_handler_with_filter.py b/mlair/data_handler/data_handler_with_filter.py index 4707fd580562a68fd6b2dc0843551905e70d7e50..07fdc41fc4dae49bd44a071dd2228c4bff860b04 100644 --- a/mlair/data_handler/data_handler_with_filter.py +++ b/mlair/data_handler/data_handler_with_filter.py @@ -3,7 +3,6 @@ __author__ = 'Lukas Leufen' __date__ = '2020-08-26' -import inspect import copy import numpy as np import pandas as pd @@ -13,8 +12,7 @@ from functools import partial import logging from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation from mlair.data_handler import DefaultDataHandler -from mlair.helpers import remove_items, to_list, TimeTrackingWrapper, statistics -from mlair.helpers.filter import KolmogorovZurbenkoFilterMovingWindow as KZFilter +from mlair.helpers import to_list, TimeTrackingWrapper, statistics from mlair.helpers.filter import FIRFilter, ClimateFIRFilter, omega_null_kzf # define a more general date type for type hinting @@ -40,7 +38,6 @@ str_or_list = Union[str, List[str]] class DataHandlerFilterSingleStation(DataHandlerSingleStation): """General data handler for a single station to be used by a superior data handler.""" - _requirements = remove_items(DataHandlerSingleStation.requirements(), "station") _hash = DataHandlerSingleStation._hash + ["filter_dim"] DEFAULT_FILTER_DIM = "filter" @@ -119,24 +116,15 @@ class DataHandlerFilter(DefaultDataHandler): self.use_filter_branches = use_filter_branches super().__init__(*args, **kwargs) - @classmethod - def own_args(cls, *args): - """Return all arguments (including kwonlyargs).""" - super_own_args = DefaultDataHandler.own_args(*args) - arg_spec = inspect.getfullargspec(cls) - list_of_args = arg_spec.args + arg_spec.kwonlyargs + super_own_args - return remove_items(list_of_args, ["self"] + list(args)) - class DataHandlerFirFilterSingleStation(DataHandlerFilterSingleStation): """Data handler for a single station to be used by a superior data handler. 
Inputs are FIR filtered.""" - _requirements = remove_items(DataHandlerFilterSingleStation.requirements(), "station") _hash = DataHandlerFilterSingleStation._hash + ["filter_cutoff_period", "filter_order", "filter_window_type"] DEFAULT_WINDOW_TYPE = ("kaiser", 5) - def __init__(self, *args, filter_cutoff_period, filter_order, filter_window_type=DEFAULT_WINDOW_TYPE, **kwargs): + def __init__(self, *args, filter_cutoff_period, filter_order, filter_window_type=DEFAULT_WINDOW_TYPE, plot_path=None, **kwargs): # self.original_data = None # ToDo: implement here something to store unfiltered data self.fs = self._get_fs(**kwargs) if filter_window_type == "kzf": @@ -147,6 +135,7 @@ class DataHandlerFirFilterSingleStation(DataHandlerFilterSingleStation): self.filter_order = self._prepare_filter_order(filter_order, removed_index, self.fs) self.filter_window_type = filter_window_type self.unfiltered_name = "unfiltered" + self.plot_path = plot_path # use this path to create insight plots super().__init__(*args, **kwargs) @staticmethod @@ -165,14 +154,11 @@ class DataHandlerFirFilterSingleStation(DataHandlerFilterSingleStation): @staticmethod def _prepare_filter_cutoff_period(filter_cutoff_period, fs): """Frequency must be smaller than the sampling frequency fs. Otherwise remove given cutoff period pair.""" - cutoff_tmp = (lambda x: [x] if isinstance(x, tuple) else to_list(x))(filter_cutoff_period) cutoff = [] removed = [] - for i, (low, high) in enumerate(cutoff_tmp): - low = low if (low is None or low > 2. / fs) else None - high = high if (high is None or high > 2. / fs) else None - if any([low, high]): - cutoff.append((low, high)) + for i, period in enumerate(to_list(filter_cutoff_period)): + if period > 2. / fs: + cutoff.append(period) else: removed.append(i) return cutoff, removed @@ -187,8 +173,7 @@ class DataHandlerFirFilterSingleStation(DataHandlerFilterSingleStation): @staticmethod def _period_to_freq(cutoff_p): - return list(map(lambda x: (1. / x[0] if x[0] is not None else None, 1. / x[1] if x[1] is not None else None), - cutoff_p)) + return [1. / x for x in cutoff_p] @staticmethod def _get_fs(**kwargs): @@ -205,10 +190,11 @@ class DataHandlerFirFilterSingleStation(DataHandlerFilterSingleStation): def apply_filter(self): """Apply FIR filter only on inputs.""" fir = FIRFilter(self.input_data.astype("float32"), self.fs, self.filter_order, self.filter_cutoff_freq, - self.filter_window_type, self.target_dim) - self.fir_coeff = fir.filter_coefficients() - fir_data = fir.filtered_data() - self.input_data = xr.concat(fir_data, pd.Index(self.create_filter_index(), name=self.filter_dim)) + self.filter_window_type, self.target_dim, self.time_dim, station_name=self.station[0], + minimum_length=self.window_history_size, offset=self.window_history_offset, plot_path=self.plot_path) + self.fir_coeff = fir.filter_coefficients + filter_data = fir.filtered_data + self.input_data = xr.concat(filter_data, pd.Index(self.create_filter_index(), name=self.filter_dim)) # this is just a code snippet to check the results of the kz filter # import matplotlib # matplotlib.use("TkAgg") @@ -216,22 +202,17 @@ class DataHandlerFirFilterSingleStation(DataHandlerFilterSingleStation): # self.input_data.sel(filter="low", variables="temp", Stations="DEBW107").plot() # self.input_data.sel(variables="temp", Stations="DEBW107").plot.line(hue="filter") - def create_filter_index(self) -> pd.Index: + def create_filter_index(self, add_unfiltered_index=True) -> pd.Index: """ - Create name for filter dimension. 
Use 'high' or 'low' for high/low pass data and 'bandi' for band pass data with - increasing numerator i (starting from 1). If 1 low, 2 band, and 1 high pass filter is used the filter index will - become to ['low', 'band1', 'band2', 'high']. + Round cut off periods in days and append 'res' for residuum index. + + Round small numbers (<10) to single decimal, and higher numbers to int. Transform as list of str and append + 'res' for residuum index. Add index unfiltered if the raw / unfiltered data is appended to data in addition. """ - index = [] - band_num = 1 - for (low, high) in self.filter_cutoff_period: - if low is None: - index.append("low") - elif high is None: - index.append("high") - else: - index.append(f"band{band_num}") - band_num += 1 + index = np.round(self.filter_cutoff_period, 1) + f = lambda x: int(np.round(x)) if x >= 10 else np.round(x, 1) + index = list(map(f, index.tolist())) + index = list(map(lambda x: str(x) + "d", index)) + ["res"] self.filter_dim_order = index return pd.Index(index, name=self.filter_dim) @@ -240,7 +221,7 @@ class DataHandlerFirFilterSingleStation(DataHandlerFilterSingleStation): def _extract_lazy(self, lazy_data): _data, _meta, _input_data, _target_data, self.fir_coeff, self.filter_dim_order = lazy_data - super(__class__, self)._extract_lazy((_data, _meta, _input_data, _target_data)) + super()._extract_lazy((_data, _meta, _input_data, _target_data)) def transform(self, data_in, dim: Union[str, int] = 0, inverse: bool = False, opts=None, transformation_dim=None): @@ -325,67 +306,6 @@ class DataHandlerFirFilter(DataHandlerFilter): data_handler = DataHandlerFirFilterSingleStation data_handler_transformation = DataHandlerFirFilterSingleStation - - -class DataHandlerKzFilterSingleStation(DataHandlerFilterSingleStation): - """Data handler for a single station to be used by a superior data handler. Inputs are kz filtered.""" - - _requirements = remove_items(inspect.getfullargspec(DataHandlerFilterSingleStation).args, ["self", "station"]) - _hash = DataHandlerFilterSingleStation._hash + ["kz_filter_length", "kz_filter_iter"] - - def __init__(self, *args, kz_filter_length, kz_filter_iter, **kwargs): - self._check_sampling(**kwargs) - # self.original_data = None # ToDo: implement here something to store unfiltered data - self.kz_filter_length = to_list(kz_filter_length) - self.kz_filter_iter = to_list(kz_filter_iter) - self.cutoff_period = None - self.cutoff_period_days = None - super().__init__(*args, **kwargs) - - @TimeTrackingWrapper - def apply_filter(self): - """Apply kolmogorov zurbenko filter only on inputs.""" - kz = KZFilter(self.input_data, wl=self.kz_filter_length, itr=self.kz_filter_iter, filter_dim=self.time_dim) - filtered_data: List[xr.DataArray] = kz.run() - self.cutoff_period = kz.period_null() - self.cutoff_period_days = kz.period_null_days() - self.input_data = xr.concat(filtered_data, pd.Index(self.create_filter_index(), name=self.filter_dim)) - # this is just a code snippet to check the results of the kz filter - # import matplotlib - # matplotlib.use("TkAgg") - # import matplotlib.pyplot as plt - # self.input_data.sel(filter="74d", variables="temp", Stations="DEBW107").plot() - # self.input_data.sel(variables="temp", Stations="DEBW107").plot.line(hue="filter") - - def create_filter_index(self) -> pd.Index: - """ - Round cut off periods in days and append 'res' for residuum index. - - Round small numbers (<10) to single decimal, and higher numbers to int. Transform as list of str and append - 'res' for residuum index. 
- """ - index = np.round(self.cutoff_period_days, 1) - f = lambda x: int(np.round(x)) if x >= 10 else np.round(x, 1) - index = list(map(f, index.tolist())) - index = list(map(lambda x: str(x) + "d", index)) + ["res"] - self.filter_dim_order = index - return pd.Index(index, name=self.filter_dim) - - def _create_lazy_data(self): - return [self._data, self.meta, self.input_data, self.target_data, self.cutoff_period, self.cutoff_period_days, - self.filter_dim_order] - - def _extract_lazy(self, lazy_data): - _data, _meta, _input_data, _target_data, self.cutoff_period, self.cutoff_period_days, \ - self.filter_dim_order = lazy_data - super(__class__, self)._extract_lazy((_data, _meta, _input_data, _target_data)) - - -class DataHandlerKzFilter(DataHandlerFilter): - """Data handler using kz filtered data.""" - - data_handler = DataHandlerKzFilterSingleStation - data_handler_transformation = DataHandlerKzFilterSingleStation _requirements = data_handler.requirements() @@ -407,21 +327,20 @@ class DataHandlerClimateFirFilterSingleStation(DataHandlerFirFilterSingleStation :param apriori_diurnal: use diurnal anomalies of each hour as addition to the apriori information type chosen by parameter apriori_type. This is only applicable for hourly resolution data. """ - - _requirements = remove_items(DataHandlerFirFilterSingleStation.requirements(), "station") - _hash = DataHandlerFirFilterSingleStation._hash + ["apriori_type", "apriori_sel_opts", "apriori_diurnal"] + _hash = DataHandlerFirFilterSingleStation._hash + ["apriori_type", "apriori_sel_opts", "apriori_diurnal", + "extend_length_opts"] _store_attributes = DataHandlerFirFilterSingleStation.store_attributes() + ["apriori"] def __init__(self, *args, apriori=None, apriori_type=None, apriori_diurnal=False, apriori_sel_opts=None, - plot_path=None, name_affix=None, **kwargs): + name_affix=None, extend_length_opts=None, **kwargs): self.apriori_type = apriori_type self.climate_filter_coeff = None # coefficents of the used FIR filter self.apriori = apriori # exogenous apriori information or None to calculate from data (endogenous) self.apriori_diurnal = apriori_diurnal self.all_apriori = None # collection of all apriori information self.apriori_sel_opts = apriori_sel_opts # ensure to separate exogenous and endogenous information - self.plot_path = plot_path # use this path to create insight plots self.plot_name_affix = name_affix + self.extend_length_opts = extend_length_opts if extend_length_opts is not None else {} super().__init__(*args, **kwargs) @TimeTrackingWrapper @@ -429,14 +348,14 @@ class DataHandlerClimateFirFilterSingleStation(DataHandlerFirFilterSingleStation """Apply FIR filter only on inputs.""" self.apriori = self.apriori.get(str(self)) if isinstance(self.apriori, dict) else self.apriori logging.info(f"{self.station}: call ClimateFIRFilter") - plot_name = str(self) # if self.plot_name_affix is None else f"{str(self)}_{self.plot_name_affix}" climate_filter = ClimateFIRFilter(self.input_data.astype("float32"), self.fs, self.filter_order, self.filter_cutoff_freq, self.filter_window_type, time_dim=self.time_dim, var_dim=self.target_dim, apriori_type=self.apriori_type, apriori=self.apriori, apriori_diurnal=self.apriori_diurnal, sel_opts=self.apriori_sel_opts, - plot_path=self.plot_path, plot_name=plot_name, - minimum_length=self.window_history_size, new_dim=self.window_dim) + plot_path=self.plot_path, + minimum_length=self.window_history_size, new_dim=self.window_dim, + station_name=self.station[0], extend_length_opts=self.extend_length_opts) 
self.climate_filter_coeff = climate_filter.filter_coefficients # store apriori information: store all if residuum_stat method was used, otherwise just store initial apriori @@ -446,8 +365,18 @@ class DataHandlerClimateFirFilterSingleStation(DataHandlerFirFilterSingleStation self.apriori = climate_filter.initial_apriori_data self.all_apriori = climate_filter.apriori_data - climate_filter_data = [c.sel({self.window_dim: slice(-self.window_history_size, 0)}) for c in - climate_filter.filtered_data] + if isinstance(self.extend_length_opts, int): + climate_filter_data = [c.sel({self.window_dim: slice(-self.window_history_size, self.extend_length_opts)}) + for c in climate_filter.filtered_data] + else: + climate_filter_data = [] + for c in climate_filter.filtered_data: + coll_tmp = [] + for v in c.coords[self.target_dim].values: + upper_lim = self.extend_length_opts.get(v, 0) + coll_tmp.append(c.sel({self.target_dim: v, + self.window_dim: slice(-self.window_history_size, upper_lim)})) + climate_filter_data.append(xr.concat(coll_tmp, self.target_dim)) # create input data with filter index input_data = xr.concat(climate_filter_data, pd.Index(self.create_filter_index(add_unfiltered_index=False), diff --git a/mlair/data_handler/data_handler_wrf_chem.py b/mlair/data_handler/data_handler_wrf_chem.py index 4bf0ca1e0d517d97fb9ddb1aa7d36d762fa69541..15837977c7a0c8f937daa697075134ab9d44b7d5 100644 --- a/mlair/data_handler/data_handler_wrf_chem.py +++ b/mlair/data_handler/data_handler_wrf_chem.py @@ -77,7 +77,7 @@ class BaseWrfChemDataLoader: staged_rotation_opts: Dict = DEFAULT_STAGED_ROTATION_opts, vars_to_rotate: Tuple[Tuple[Tuple[str, str], Tuple[str, str]]] = DEFAULT_VARS_TO_ROTATE, staged_dimension_mapping=None, stag_ending='_stag', - date_format_of_nc_file=None, + date_format_of_nc_file=None, vars_for_unit_conv: Dict = None ): """ Initialisze data loader @@ -141,6 +141,10 @@ class BaseWrfChemDataLoader: else: self.staged_dimension_mapping = staged_dimension_mapping + self.vars_for_unit_conv = vars_for_unit_conv + # chemical convs. 
+ self._parts_per_exponents = {'ppmv': 6, 'ppbv': 9, 'pptv': 12, 'ppqv': 15} + # internal self._X = None self._Y = None @@ -189,6 +193,10 @@ class BaseWrfChemDataLoader: else: raise ValueError(f"`start_time' and `end_time' must both be given or None.") + def convert_chem(self, data, from_unit, to_unit): + convert_exponent = self._parts_per_exponents[to_unit] - self._parts_per_exponents[from_unit] + return data * 10**convert_exponent + @TimeTrackingWrapper def open_data(self): """ @@ -492,6 +500,9 @@ class SingleGridColumnWrfChemDataLoader(BaseWrfChemDataLoader): self.apply_staged_transormation() self._set_geoinfos() + if self.vars_for_unit_conv is not None: + self.convert_chemical_units() + if self.lazy is False: self.reset_data_by_other(self.apply_toarstats()) else: @@ -544,6 +555,14 @@ class SingleGridColumnWrfChemDataLoader(BaseWrfChemDataLoader): hash = "".join([str(self.__getattribute__(e)) for e in self._hash_list()]).encode() return hashlib.md5(hash).hexdigest() + def convert_chemical_units(self): + with xr.set_options(keep_attrs=True): + for var, to_unit in self.vars_for_unit_conv.items(): + from_unit = self.data[var].attrs['units'] + data = self.convert_chem(self.data[var], from_unit, to_unit) + data.attrs['units'] = to_unit + self.data[var] = data + def __exit__(self, exc_type, exc_val, exc_tb): self.data.close() gc.collect() @@ -816,6 +835,8 @@ class DataHandlerSingleGridColumn(DataHandlerSingleStation): time_zone=None, target_time_type=None, input_output_sampling4toarstats : tuple = None, + experiment_path: str = None, + vars_for_unit_conv: Dict = None, **kwargs): self.external_coords_file = external_coords_file self.var_logical_z_coord_selector = self._return_z_coord_select_if_valid(var_logical_z_coord_selector, @@ -833,6 +854,8 @@ class DataHandlerSingleGridColumn(DataHandlerSingleStation): self.time_zone = time_zone self.target_time_type = target_time_type self.input_output_sampling4toarstats = input_output_sampling4toarstats + self.experiment_path = experiment_path + self.vars_for_unit_conv = vars_for_unit_conv super().__init__(*args, **kwargs) @staticmethod @@ -920,6 +943,7 @@ class DataHandlerSingleGridColumn(DataHandlerSingleStation): target_time_type=self.target_time_type, station=self.station, lazy_preprocessing=True, + vars_for_unit_conv=self.vars_for_unit_conv, ) self.__loader = loader @@ -1209,8 +1233,30 @@ class DataHandlerSectorGrid(DataHandlerSingleGridColumn): wind_sector_edge_dim_name=self.wind_sector_edge_dim_name) self._added_vars = [] self.wind_dir_name = None + self._wind_upstream_sector_by_name = None + # self.wind_upstream_sector_by_name = None super().__init__(*args, **kwargs) + def get_wind_upstream_sector_by_name(self): + return self.wind_upstream_sector_by_name + + @property + def wind_upstream_sector_by_name(self): + return self._wind_upstream_sector_by_name + + @wind_upstream_sector_by_name.setter + def wind_upstream_sector_by_name(self, wind_upstream_sector_by_name: xr.DataArray): + self._wind_upstream_sector_by_name = wind_upstream_sector_by_name + + def _store_wind_upstream_sector_by_name(self): + file_name = os.path.join(self.experiment_path, + f"data/{self.station[0]}_{self.start}_{self.end}_upstream_wind_sector.nc") + wind_upstream_sector_by_name = self.wind_upstream_sector_by_name + dims_to_expand = list(wind_upstream_sector_by_name.coords._names - set(wind_upstream_sector_by_name.dims)) + wind_upstream_sector_by_name = wind_upstream_sector_by_name.expand_dims(dims_to_expand) + wind_upstream_sector_by_name = 
wind_upstream_sector_by_name.to_dataset(self.iter_dim) + wind_upstream_sector_by_name.to_netcdf(file_name) + @TimeTrackingWrapper def extract_data_from_loader(self, loader): wind_dir_name = self._get_wind_dir_var_name(loader) @@ -1274,12 +1320,7 @@ class DataHandlerSectorGrid(DataHandlerSingleGridColumn): @TimeTrackingWrapper def modify_history(self): if self.transformation_is_applied: - ws_edges = self.get_applied_transdormation_on_wind_sector_edges() - wind_dir_of_interest = self.compute_wind_dir_of_interest() - sector_allocation = self.windsector.get_sect_of_value(value=wind_dir_of_interest, external_edges=ws_edges) - sector_allocation = sector_allocation.squeeze() - existing_sectors = np.unique(sector_allocation.data) - sector_history, sector_history_var_names = self.setup_history_like_xr_and_var_names() + existing_sectors, sector_allocation, sector_history, sector_history_var_names = self.prepare_sector_allocation_and_history() with self.loader as loader, TimeTracking(name="loader in modify history"): # setup sector history grid_data = self.preselect_and_transform_neighbouring_data_based_on_radius(loader) @@ -1295,6 +1336,17 @@ class DataHandlerSectorGrid(DataHandlerSingleGridColumn): else: return self.history + def prepare_sector_allocation_and_history(self): + ws_edges = self.get_applied_transdormation_on_wind_sector_edges() + wind_dir_of_interest = self.compute_wind_dir_of_interest() + sector_allocation = self.windsector.get_sect_of_value(value=wind_dir_of_interest, external_edges=ws_edges) + sector_allocation = sector_allocation.squeeze() + existing_sectors = np.unique(sector_allocation.data) + sector_history, sector_history_var_names = self.setup_history_like_xr_and_var_names() + self.wind_upstream_sector_by_name = sector_allocation + self._store_wind_upstream_sector_by_name() + return existing_sectors, sector_allocation, sector_history, sector_history_var_names + def setup_history_like_xr_and_var_names(self, var_name_suffix="Sect"): """ Returns ones_like xarray from self.history and list of variable names which can be modified by passing a @@ -1406,11 +1458,7 @@ class DataHandler3SectorGrid(DataHandlerSectorGrid): @TimeTrackingWrapper def modify_history(self): if self.transformation_is_applied: - ws_edges = self.get_applied_transdormation_on_wind_sector_edges() - wind_dir_of_interest = self.compute_wind_dir_of_interest() - sector_allocation = self.windsector.get_sect_of_value(value=wind_dir_of_interest, external_edges=ws_edges) - existing_sectors = np.unique(sector_allocation.data) - sector_history, sector_history_var_names = self.setup_history_like_xr_and_var_names() + existing_sectors, sector_allocation, sector_history, sector_history_var_names = self.prepare_sector_allocation_and_history() sector_history_left, sector_history_var_names_left = self.setup_history_like_xr_and_var_names(var_name_suffix="SectLeft") sector_history_right, sector_history_var_names_right = self.setup_history_like_xr_and_var_names(var_name_suffix="SectRight") with self.loader as loader, TimeTracking(name="loader in modify history"): diff --git a/mlair/data_handler/default_data_handler.py b/mlair/data_handler/default_data_handler.py index 3649eaf65ed7efe45c5fac54f892bd1e471e5838..0f97d90c14dc3f6db7243b277ab45680223c6a25 100644 --- a/mlair/data_handler/default_data_handler.py +++ b/mlair/data_handler/default_data_handler.py @@ -20,7 +20,7 @@ import numpy as np import xarray as xr from mlair.data_handler.abstract_data_handler import AbstractDataHandler -from mlair.helpers import remove_items, to_list +from 
mlair.helpers import remove_items, to_list, TimeTrackingWrapper from mlair.helpers.join import EmptyQueryResult @@ -32,8 +32,9 @@ class DefaultDataHandler(AbstractDataHandler): from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation as data_handler from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation as data_handler_transformation - _requirements = remove_items(inspect.getfullargspec(data_handler).args, ["self", "station"]) + _requirements = data_handler.requirements() _store_attributes = data_handler.store_attributes() + _skip_args = AbstractDataHandler._skip_args + ["id_class"] DEFAULT_ITER_DIM = "Stations" DEFAULT_TIME_DIM = "datetime" @@ -76,10 +77,6 @@ class DefaultDataHandler(AbstractDataHandler): def _create_collection(self): return [self.id_class] - @classmethod - def requirements(cls): - return remove_items(super().requirements(), "id_class") - def _reset_data(self): self._X, self._Y, self._X_extreme, self._Y_extreme = None, None, None, None gc.collect() @@ -159,6 +156,7 @@ class DefaultDataHandler(AbstractDataHandler): self._reset_data() if no_data is True else None return self._to_numpy([Y]) if as_numpy is True else Y + @TimeTrackingWrapper def harmonise_X(self): X_original, Y_original = self.get_X_original(), self.get_Y_original() dim = self.time_dim @@ -181,6 +179,7 @@ class DefaultDataHandler(AbstractDataHandler): def apply_transformation(self, data, base="target", dim=0, inverse=False): return self.id_class.apply_transformation(data, dim=dim, base=base, inverse=inverse) + @TimeTrackingWrapper def multiply_extremes(self, extreme_values: num_or_list = 1., extremes_on_right_tail_only: bool = False, timedelta: Tuple[int, str] = (1, 'm'), dim=DEFAULT_TIME_DIM): """ @@ -288,6 +287,7 @@ class DefaultDataHandler(AbstractDataHandler): transformation_dict = ({}, {}) max_process = kwargs.get("max_number_multiprocessing", 16) + set_stations = to_list(set_stations) n_process = min([psutil.cpu_count(logical=False), len(set_stations), max_process]) # use only physical cpus if n_process > 1 and kwargs.get("use_multiprocessing", True) is True: # parallel solution logging.info("use parallel transformation approach") @@ -304,6 +304,7 @@ class DefaultDataHandler(AbstractDataHandler): os.remove(_res_file) transformation_dict = cls.update_transformation_dict(dh, transformation_dict) pool.close() + pool.join() else: # serial solution logging.info("use serial transformation approach") sp_keys.update({"return_strategy": "result"}) diff --git a/mlair/data_handler/input_bootstraps.py b/mlair/data_handler/input_bootstraps.py index ab4c71f1c77ab12b0e751f6991fbf20cd55aa8ae..b8ad614f2317e804d415b23308df760f4dd8da7f 100644 --- a/mlair/data_handler/input_bootstraps.py +++ b/mlair/data_handler/input_bootstraps.py @@ -28,12 +28,13 @@ class BootstrapIterator(Iterator): _position: int = None - def __init__(self, data: "Bootstraps", method): + def __init__(self, data: "Bootstraps", method, return_reshaped=False): assert isinstance(data, Bootstraps) self._data = data self._dimension = data.bootstrap_dimension self.boot_dim = "boots" self._method = method + self._return_reshaped = return_reshaped self._collection = self.create_collection(self._data.data, self._dimension) self._position = 0 @@ -46,12 +47,15 @@ class BootstrapIterator(Iterator): raise NotImplementedError def _reshape(self, d): - if isinstance(d, list): - return list(map(lambda x: self._reshape(x), d)) - # return list(map(lambda x: np.rollaxis(x, -1, 0).reshape(x.shape[0] * x.shape[-1], 
*x.shape[1:-1]), d)) + if self._return_reshaped: + if isinstance(d, list): + return list(map(lambda x: self._reshape(x), d)) + # return list(map(lambda x: np.rollaxis(x, -1, 0).reshape(x.shape[0] * x.shape[-1], *x.shape[1:-1]), d)) + else: + shape = d.shape + return np.rollaxis(d, -1, 0).reshape(shape[0] * shape[-1], *shape[1:-1]) else: - shape = d.shape - return np.rollaxis(d, -1, 0).reshape(shape[0] * shape[-1], *shape[1:-1]) + return d def _to_numpy(self, d): if isinstance(d, list): @@ -75,8 +79,8 @@ class BootstrapIterator(Iterator): class BootstrapIteratorSingleInput(BootstrapIterator): _position: int = None - def __init__(self, *args): - super().__init__(*args) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) def __next__(self): """Return next element or stop iteration.""" @@ -107,8 +111,8 @@ class BootstrapIteratorSingleInput(BootstrapIterator): class BootstrapIteratorVariable(BootstrapIterator): - def __init__(self, *args): - super().__init__(*args) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) def __next__(self): """Return next element or stop iteration.""" @@ -119,11 +123,12 @@ class BootstrapIteratorVariable(BootstrapIterator): _X = list(map(lambda x: x.expand_dims({self.boot_dim: range(nboot)}, axis=-1), _X)) _Y = _Y.expand_dims({self.boot_dim: range(nboot)}, axis=-1) for index in range(len(_X)): - single_variable = _X[index].sel({self._dimension: [dimension]}) - bootstrapped_variable = self.apply_bootstrap_method(single_variable.values) - bootstrapped_data = xr.DataArray(bootstrapped_variable, coords=single_variable.coords, - dims=single_variable.dims) - _X[index] = bootstrapped_data.combine_first(_X[index]).transpose(*_X[index].dims) + if dimension in _X[index].coords[self._dimension]: + single_variable = _X[index].sel({self._dimension: [dimension]}) + bootstrapped_variable = self.apply_bootstrap_method(single_variable.values) + bootstrapped_data = xr.DataArray(bootstrapped_variable, coords=single_variable.coords, + dims=single_variable.dims) + _X[index] = bootstrapped_data.combine_first(_X[index]).transpose(*_X[index].dims) self._position += 1 except IndexError: raise StopIteration() @@ -140,8 +145,8 @@ class BootstrapIteratorVariable(BootstrapIterator): class BootstrapIteratorBranch(BootstrapIterator): - def __init__(self, *args): - super().__init__(*args) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) def __next__(self): try: diff --git a/mlair/helpers/filter.py b/mlair/helpers/filter.py index 543cff3624577ac617733b8b593c5f52f25196b3..b25d6ee10f89bfa49c2147d1758a2d24b8e7687e 100644 --- a/mlair/helpers/filter.py +++ b/mlair/helpers/filter.py @@ -1,6 +1,6 @@ import gc import warnings -from typing import Union, Callable, Tuple +from typing import Union, Callable, Tuple, Dict, Any import logging import os import time @@ -17,49 +17,158 @@ from mlair.helpers import to_list, TimeTrackingWrapper, TimeTracking class FIRFilter: + from mlair.plotting.data_insight_plotting import PlotFirFilter + + def __init__(self, data, fs, order, cutoff, window, var_dim, time_dim, station_name=None, minimum_length=None, offset=0, plot_path=None): + self._filtered = [] + self._h = [] + self.data = data + self.fs = fs + self.order = order + self.cutoff = cutoff + self.window = window + self.var_dim = var_dim + self.time_dim = time_dim + self.station_name = station_name + self.minimum_length = minimum_length + self.offset = offset + self.plot_path = plot_path + self.run() - def __init__(self, data, fs, order, cutoff, 
window, dim): - + def run(self): + logging.info(f"{self.station_name}: start {self.__class__.__name__}") filtered = [] h = [] - for i in range(len(order)): - fi, hi = fir_filter(data, fs, order=order[i], cutoff_low=cutoff[i][0], cutoff_high=cutoff[i][1], - window=window, dim=dim, h=None, causal=True, padlen=None) + input_data = self.data.__deepcopy__() + + # collect some data for visualization + plot_pos = np.array([0.25, 1.5, 2.75, 4]) * 365 * self.fs + plot_dates = [input_data.isel({self.time_dim: int(pos)}).coords[self.time_dim].values for pos in plot_pos if + pos < len(input_data.coords[self.time_dim])] + plot_data = [] + + for i in range(len(self.order)): + # apply filter + fi, hi = self.fir_filter(input_data, self.fs, self.cutoff[i], self.order[i], time_dim=self.time_dim, + var_dim=self.var_dim, window=self.window, station_name=self.station_name) filtered.append(fi) h.append(hi) + # visualization + plot_data.append(self.create_visualization(fi, input_data, plot_dates, self.time_dim, self.fs, hi, + self.minimum_length, self.order, i, self.offset, self.var_dim)) + # calculate residuum + input_data = input_data - fi + + # add last residuum to filtered + filtered.append(input_data) + self._filtered = filtered self._h = h + # visualize + if self.plot_path is not None: + try: + self.PlotFirFilter(self.plot_path, plot_data, self.station_name) # not working when t0 != 0 + except Exception as e: + logging.info(f"Could not plot climate fir filter due to following reason:\n{e}") + + def create_visualization(self, filtered, filter_input_data, plot_dates, time_dim, sampling, + h, minimum_length, order, i, offset, var_dim): # pragma: no cover + plot_data = [] + minimum_length = minimum_length or 0 + for viz_date in set(plot_dates).intersection(filtered.coords[time_dim].values): + try: + if i < len(order) - 1: + minimum_length += order[i+1] + + td_type = {1: "D", 24: "h"}.get(sampling) + length = len(h) + extend_length_history = minimum_length + int((length + 1) / 2) + extend_length_future = int((length + 1) / 2) + 1 + t_minus = viz_date + np.timedelta64(int(-extend_length_history), td_type) + t_plus = viz_date + np.timedelta64(int(extend_length_future + offset), td_type) + time_slice = slice(t_minus, t_plus - np.timedelta64(1, td_type)) + plot_data.append({"t0": viz_date, "filter_input": filter_input_data.sel({time_dim: time_slice}), + "filtered": filtered.sel({time_dim: time_slice}), "h": h, "time_dim": time_dim, + "var_dim": var_dim}) + except: + pass + return plot_data + + @property def filter_coefficients(self): return self._h + @property def filtered_data(self): return self._filtered - # - # y, h = fir_filter(station_data.values.flatten(), fs, order[0], cutoff_low=cutoff[0][0], cutoff_high=cutoff[0][1], - # window=window) - # filtered = xr.ones_like(station_data) * y.reshape(station_data.values.shape) - # # band pass - # y_band, h_band = fir_filter(station_data.values.flatten(), fs, order[1], cutoff_low=cutoff[1][0], - # cutoff_high=cutoff[1][1], window=window) - # filtered_band = xr.ones_like(station_data) * y_band.reshape(station_data.values.shape) - # # band pass 2 - # y_band_2, h_band_2 = fir_filter(station_data.values.flatten(), fs, order[2], cutoff_low=cutoff[2][0], - # cutoff_high=cutoff[2][1], window=window) - # filtered_band_2 = xr.ones_like(station_data) * y_band_2.reshape(station_data.values.shape) - # # high pass - # y_high, h_high = fir_filter(station_data.values.flatten(), fs, order[3], cutoff_low=cutoff[3][0], - # cutoff_high=cutoff[3][1], window=window) - # filtered_high = 
xr.ones_like(station_data) * y_high.reshape(station_data.values.shape) - - -class ClimateFIRFilter: + + @TimeTrackingWrapper + def fir_filter(self, data, fs, cutoff_high, order, sampling="1d", time_dim="datetime", var_dim="variables", window: Union[str, Tuple] = "hamming", + minimum_length=None, new_dim="window", plot_dates=None, station_name=None): + + # calculate FIR filter coefficients + h = self._calculate_filter_coefficients(window, order, cutoff_high, fs) + + coll = [] + for var in data.coords[var_dim]: + d = data.sel({var_dim: var}) + filt = xr.apply_ufunc(fir_filter_convolve, d, + input_core_dims=[[time_dim]], output_core_dims=[[time_dim]], + vectorize=True, kwargs={"h": h}, output_dtypes=[d.dtype]) + coll.append(filt) + filtered = xr.concat(coll, var_dim) + + # create result array with same shape like input data, gaps are filled by nans + filtered = self._create_full_filter_result_array(data, filtered, time_dim, station_name) + return filtered, h + + @staticmethod + def _calculate_filter_coefficients(window: Union[str, tuple], order: Union[int, tuple], cutoff_high: float, + fs: float) -> np.array: + """ + Calculate filter coefficients for moving window using scipy's signal package for common filter types and local + method firwin_kzf for Kolmogorov Zurbenko filter (kzf). The filter is a low-pass filter. + + :param window: name of the window type which is either a string with the window's name or a tuple containing the + name but also some parameters (e.g. `("kaiser", 5)`) + :param order: order of the filter to create as int or parameters m and k of kzf + :param cutoff_high: cutoff frequency to use for low-pass filter in frequency of fs + :param fs: sampling frequency of time series + """ + if window == "kzf": + h = firwin_kzf(*order) + else: + h = signal.firwin(order, cutoff_high, pass_zero="lowpass", fs=fs, window=window) + return h + + @staticmethod + def _create_full_filter_result_array(template_array: xr.DataArray, result_array: xr.DataArray, new_dim: str, + station_name: str = None) -> xr.DataArray: + """ + Create result filter array with same shape line given template data (should be the original input data before + filtering the data). All gaps are filled by nans. 
+ + :param template_array: this array is used as template for shape and ordering of dims + :param result_array: array with data that are filled into template + :param new_dim: new dimension which is shifted/appended to/at the end (if present or not) + :param station_name: string that is attached to logging (default None) + """ + logging.debug(f"{station_name}: create res_full") + new_coords = {**{k: template_array.coords[k].values for k in template_array.coords if k != new_dim}, + new_dim: result_array.coords[new_dim]} + dims = [*template_array.dims, new_dim] if new_dim not in template_array.dims else template_array.dims + result_array = result_array.transpose(*dims) + return result_array.broadcast_like(xr.DataArray(dims=dims, coords=new_coords)) + + +class ClimateFIRFilter(FIRFilter): from mlair.plotting.data_insight_plotting import PlotClimateFirFilter def __init__(self, data, fs, order, cutoff, window, time_dim, var_dim, apriori=None, apriori_type=None, - apriori_diurnal=False, sel_opts=None, plot_path=None, plot_name=None, - minimum_length=None, new_dim=None): + apriori_diurnal=False, sel_opts=None, plot_path=None, + minimum_length=None, new_dim=None, station_name=None, extend_length_opts: Union[dict, int] = 0): """ :param data: data to filter :param fs: sampling frequency in 1/days -> 1d: fs=1 -> 1H: fs=24 @@ -75,111 +184,115 @@ class ClimateFIRFilter: residua is used ("residuum_stats"). :param apriori_diurnal: Use diurnal cycle as additional apriori information (only applicable for hourly resoluted data). The mean anomaly of each hour is added to the apriori_type information. + :param extend_length_opts: shift information switch between historical data and apriori estimation by the given + values (default None). Must either be a dictionary with keys available in var_dim or a single value that is + applied to all data. 
""" - logging.info(f"{plot_name}: start init ClimateFIRFilter") - self.plot_path = plot_path - self.plot_name = plot_name + self._apriori = apriori + self.apriori_type = apriori_type + self.apriori_diurnal = apriori_diurnal + self._apriori_list = [] + self.sel_opts = sel_opts + self.new_dim = new_dim self.plot_data = [] + self.extend_length_opts = extend_length_opts + super().__init__(data, fs, order, cutoff, window, var_dim, time_dim, station_name=station_name, + minimum_length=minimum_length, plot_path=plot_path) + + def run(self): filtered = [] h = [] - if sel_opts is not None: - sel_opts = sel_opts if isinstance(sel_opts, dict) else {time_dim: sel_opts} - sampling = {1: "1d", 24: "1H"}.get(int(fs)) - logging.debug(f"{plot_name}: create diurnal_anomalies") - if apriori_diurnal is True and sampling == "1H": - # diurnal_anomalies = self.create_hourly_mean(data, sel_opts=sel_opts, sampling=sampling, time_dim=time_dim, - # as_anomaly=True) - diurnal_anomalies = self.create_seasonal_hourly_mean(data, sel_opts=sel_opts, sampling=sampling, - time_dim=time_dim, - as_anomaly=True) + if self.sel_opts is not None: + self.sel_opts = self.sel_opts if isinstance(self.sel_opts, dict) else {self.time_dim: self.sel_opts} + sampling = {1: "1d", 24: "1H"}.get(int(self.fs)) + logging.debug(f"{self.station_name}: create diurnal_anomalies") + if self.apriori_diurnal is True and sampling == "1H": + diurnal_anomalies = self.create_seasonal_hourly_mean(self.data, self.time_dim, sel_opts=self.sel_opts, + sampling=sampling, as_anomaly=True) else: diurnal_anomalies = 0 - logging.debug(f"{plot_name}: create monthly apriori") - if apriori is None: - apriori = self.create_monthly_mean(data, sel_opts=sel_opts, sampling=sampling, - time_dim=time_dim) + diurnal_anomalies - logging.debug(f"{plot_name}: apriori shape = {apriori.shape}") - apriori_list = to_list(apriori) - input_data = data.__deepcopy__() + logging.debug(f"{self.station_name}: create monthly apriori") + if self._apriori is None: + self._apriori = self.create_monthly_mean(self.data, self.time_dim, sel_opts=self.sel_opts, + sampling=sampling) + diurnal_anomalies + logging.debug(f"{self.station_name}: apriori shape = {self._apriori.shape}") + apriori_list = to_list(self._apriori) + input_data = self.data.__deepcopy__() # for viz plot_dates = None # create tmp dimension to apply filter, search for unused name - new_dim = self._create_tmp_dimension(input_data) if new_dim is None else new_dim + new_dim = self._create_tmp_dimension(input_data) if self.new_dim is None else self.new_dim - for i in range(len(order)): - logging.info(f"{plot_name}: start filter for order {order[i]}") + for i in range(len(self.order)): + logging.info(f"{self.station_name}: start filter for order {self.order[i]}") # calculate climatological filter - # ToDo: remove all methods except the vectorized version - _minimum_length = self._minimum_length(order, minimum_length, i, window) - fi, hi, apriori, plot_data = self.clim_filter(input_data, fs, cutoff[i], order[i], - apriori=apriori_list[i], - sel_opts=sel_opts, sampling=sampling, time_dim=time_dim, - window=window, var_dim=var_dim, + _minimum_length = self._minimum_length(self.order, self.minimum_length, i, self.window) + fi, hi, apriori, plot_data = self.clim_filter(input_data, self.fs, self.cutoff[i], self.order[i], + apriori=apriori_list[i], sel_opts=self.sel_opts, + sampling=sampling, time_dim=self.time_dim, + window=self.window, var_dim=self.var_dim, minimum_length=_minimum_length, new_dim=new_dim, - plot_dates=plot_dates) + 
plot_dates=plot_dates, station_name=self.station_name, + extend_length_opts=self.extend_length_opts) - logging.info(f"{plot_name}: finished clim_filter calculation") - if minimum_length is None: + logging.info(f"{self.station_name}: finished clim_filter calculation") + if self.minimum_length is None: filtered.append(fi) else: - filtered.append(fi.sel({new_dim: slice(-minimum_length, 0)})) + filtered.append(fi.sel({new_dim: slice(-self.minimum_length, None)})) h.append(hi) gc.collect() self.plot_data.append(plot_data) plot_dates = {e["t0"] for e in plot_data} # calculate residuum - logging.info(f"{plot_name}: calculate residuum") + logging.info(f"{self.station_name}: calculate residuum") coord_range = range(fi.coords[new_dim].values.min(), fi.coords[new_dim].values.max() + 1) if new_dim in input_data.coords: input_data = input_data.sel({new_dim: coord_range}) - fi else: - input_data = self._shift_data(input_data, coord_range, time_dim, var_dim, new_dim) - fi + input_data = self._shift_data(input_data, coord_range, self.time_dim, new_dim) - fi # create new apriori information for next iteration if no further apriori is provided - if len(apriori_list) <= i + 1: - logging.info(f"{plot_name}: create diurnal_anomalies") - if apriori_diurnal is True and sampling == "1H": - # diurnal_anomalies = self.create_hourly_mean(input_data.sel({new_dim: 0}, drop=True), - # sel_opts=sel_opts, sampling=sampling, - # time_dim=time_dim, as_anomaly=True) + if len(apriori_list) < len(self.order): + logging.info(f"{self.station_name}: create diurnal_anomalies") + if self.apriori_diurnal is True and sampling == "1H": diurnal_anomalies = self.create_seasonal_hourly_mean(input_data.sel({new_dim: 0}, drop=True), - sel_opts=sel_opts, sampling=sampling, - time_dim=time_dim, as_anomaly=True) + self.time_dim, sel_opts=self.sel_opts, + sampling=sampling, as_anomaly=True) else: diurnal_anomalies = 0 - logging.info(f"{plot_name}: create monthly apriori") - if apriori_type is None or apriori_type == "zeros": # zero version + logging.info(f"{self.station_name}: create monthly apriori") + if self.apriori_type is None or self.apriori_type == "zeros": # zero version apriori_list.append(xr.zeros_like(apriori_list[i]) + diurnal_anomalies) - elif apriori_type == "residuum_stats": # calculate monthly statistic on residuum + elif self.apriori_type == "residuum_stats": # calculate monthly statistic on residuum apriori_list.append( - -self.create_monthly_mean(input_data.sel({new_dim: 0}, drop=True), sel_opts=sel_opts, - sampling=sampling, - time_dim=time_dim) + diurnal_anomalies) + -self.create_monthly_mean(input_data.sel({new_dim: 0}, drop=True), self.time_dim, + sel_opts=self.sel_opts, sampling=sampling) + diurnal_anomalies) else: - raise ValueError(f"Cannot handle unkown apriori type: {apriori_type}. Please choose from None, " - f"`zeros` or `residuum_stats`.") + raise ValueError(f"Cannot handle unkown apriori type: {self.apriori_type}. 
Please choose from None," + f" `zeros` or `residuum_stats`.") # add last residuum to filtered - if minimum_length is None: + if self.minimum_length is None: filtered.append(input_data) else: - filtered.append(input_data.sel({new_dim: slice(-minimum_length, 0)})) - # filtered.append(input_data) + filtered.append(input_data.sel({new_dim: slice(-self.minimum_length, None)})) + self._filtered = filtered self._h = h - self._apriori = apriori_list + self._apriori_list = apriori_list # visualize if self.plot_path is not None: try: - self.PlotClimateFirFilter(self.plot_path, self.plot_data, sampling, plot_name) + self.PlotClimateFirFilter(self.plot_path, self.plot_data, sampling, self.station_name) except Exception as e: logging.info(f"Could not plot climate fir filter due to following reason:\n{e}") @staticmethod - def _minimum_length(order, minimum_length, pos, window): + def _minimum_length(order: list, minimum_length: Union[int, None], pos: int, window: Union[str, tuple]) -> int: next_order = 0 if pos + 1 < len(order): next_order = order[pos + 1] @@ -190,8 +303,16 @@ class ClimateFIRFilter: return next_order if next_order > 0 else None @staticmethod - def create_unity_array(data, time_dim, extend_range=366): - """Create a xr data array filled with ones. time_dim is extended by extend_range days in future and past.""" + def create_monthly_unity_array(data: xr.DataArray, time_dim: str, extend_range: int = 366) -> xr.DataArray: + """ + Create a xarray data array filled with ones with monthly resolution (set on 16th of month). Data is extended by + extend_range days in future and past along time_dim. + + :param data: data to create monthly unity array from, must contain dimension time_dim + :param time_dim: name of temporal dimension + :param extend_range: number of days to extend data (default 366) + :returns: xarray in monthly resolution (centered at 16th day of month) with all values equal to 1 + """ coords = data.coords # extend time_dim by given extend_range days @@ -206,11 +327,28 @@ class ClimateFIRFilter: # loffset is required because resampling uses last day in month as resampling timestamp return new_array.resample({time_dim: "1m"}, loffset=datetime.timedelta(days=-15)).max() - def create_monthly_mean(self, data, sel_opts=None, sampling="1d", time_dim="datetime"): - """Calculate monthly statistics.""" + def create_monthly_mean(self, data: xr.DataArray, time_dim: str, sel_opts: dict = None, sampling: str = "1d") \ + -> xr.DataArray: + """ + Calculate monthly means (12 values) and return a data array with same resolution as given data containing these + monthly mean values. Sampling points are the 16th of each month (this value is equal to the true monthly mean) + and all other values between two points are interpolated linearly. It is possible to apply some pre-selection + to use only a subset of given data using the sel_opts parameter. Only data from this subset are used to + calculate the monthly statistic. + + :param data: data to apply statistical calculation on + :param time_dim: name of temporal axis + :param sel_opts: selection options as dict to select a subset of data (default None). A given sel_opts with + `sel_opts={<time_dim>: "2006"}` forces the method e.g. to derive the monthly means only from data of the + year 2006. + :param sampling: sampling of the returned data (default 1d) + :returns: array in desired resolution containing interpolated monthly values. 
Months with no valid data are + returned as np.nan which also effects data in the neighbouring months (before / after sampling points which + are the 16th of each month). + """ # create unity xarray in monthly resolution with sampling point in mid of each month - monthly = self.create_unity_array(data, time_dim) + monthly = self.create_monthly_unity_array(data, time_dim) * np.nan # apply selection if given (only use subset for monthly means) if sel_opts is not None: @@ -225,35 +363,68 @@ class ClimateFIRFilter: # transform monthly information into original sampling rate return monthly.resample({time_dim: sampling}).interpolate() - # for month in monthly_mean.month.values: - # loc = (monthly[f"{time_dim}.month"] == month) - # monthly.loc[{time_dim: loc}] = monthly_mean.sel(month=month, drop=True) - # aggregate monthly information (shift by half month, because resample base is last day) - # return monthly.resample({time_dim: "1m"}).max().resample({time_dim: sampling}).interpolate() - @staticmethod - def create_hourly_mean(data, sel_opts=None, sampling="1H", time_dim="datetime", as_anomaly=True): - """Calculate hourly statistics. Either the absolute value or the anomaly (as_anomaly=True).""" - # can only be used for hourly sampling rate - assert sampling == "1H" - - # create unity xarray in hourly resolution - hourly = xr.ones_like(data) + def _compute_hourly_mean_per_month(data: xr.DataArray, time_dim: str, as_anomaly: bool) -> Dict[int, xr.DataArray]: + """ + Calculate for each hour in each month a separate mean value (12 x 24 values in total). Average is either the + anomaly of a monthly mean state or the raw mean value. - # apply selection if given (only use subset for hourly means) - if sel_opts is not None: - data = data.sel(**sel_opts) + :param data: data to calculate averages on + :param time_dim: name of temporal dimension + :param as_anomaly: indicates whether to calculate means as anomaly of a monthly mean or as raw mean values. + :returns: dictionary containing 12 months each with a 24-valued array (1 entry for each hour) + """ + seasonal_hourly_means = {} + for month in data.groupby(f"{time_dim}.month").groups.keys(): + single_month_data = data.sel({time_dim: (data[f"{time_dim}.month"] == month)}) + hourly_mean = single_month_data.groupby(f"{time_dim}.hour").mean() + if as_anomaly is True: + hourly_mean = hourly_mean - hourly_mean.mean("hour") + seasonal_hourly_means[month] = hourly_mean + return seasonal_hourly_means - # create mean for each hour and replace entries in unity array, calculate anomaly if enabled - hourly_mean = data.groupby(f"{time_dim}.hour").mean() - if as_anomaly is True: - hourly_mean = hourly_mean - hourly_mean.mean("hour") - for hour in hourly_mean.hour.values: - loc = (hourly[f"{time_dim}.hour"] == hour) - hourly.loc[{f"{time_dim}": loc}] = hourly_mean.sel(hour=hour) - return hourly + @staticmethod + def _create_seasonal_cycle_of_single_hour_mean(result_arr: xr.DataArray, means: Dict[int, xr.DataArray], hour: int, + time_dim: str, sampling: str) -> xr.DataArray: + """ + Use monthly means of a given hour to create an array with interpolated values at the indicated hour for each day + of the full time span indicated by given result_arr. 
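
A small self-contained sketch of the per-month hourly statistics computed by `_compute_hourly_mean_per_month` above; the toy data are made up and the subtraction corresponds to the `as_anomaly=True` branch:

    import numpy as np
    import pandas as pd
    import xarray as xr

    time = pd.date_range("2006-01-01", "2006-12-31 23:00", freq="H")
    data = xr.DataArray(np.sin(2 * np.pi * time.hour / 24) + 0.1 * np.random.randn(time.size),
                        coords={"datetime": time}, dims="datetime")

    seasonal_hourly_means = {}
    for month, group in data.groupby("datetime.month"):
        hourly_mean = group.groupby("datetime.hour").mean()    # 24 values for this month
        hourly_mean = hourly_mean - hourly_mean.mean("hour")   # anomaly of the month's own mean
        seasonal_hourly_means[month] = hourly_mean             # 12 x 24 values in total
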
+ + :param result_arr: template array indicating the full time range and additional dimensions to keep + :param means: dictionary containing 24 hourly averages for each month (12 x 24 values in total) + :param hour: integer of hour of interest + :param time_dim: name of temporal dimension + :param sampling: sampling rate to interpolate + :returns: array with interpolated averages in sampling resolution containing only values for hour of interest + """ + h_coll = xr.ones_like(result_arr) * np.nan + for month in means.keys(): + hourly_mean_single_month = means[month].sel(hour=hour, drop=True) + h_coll = xr.where((h_coll[f"{time_dim}.month"] == month), hourly_mean_single_month, h_coll) + h_coll = h_coll.resample({time_dim: sampling}).interpolate() + h_coll = h_coll.sel({time_dim: (h_coll[f"{time_dim}.hour"] == hour)}) + return h_coll + + def create_seasonal_hourly_mean(self, data: xr.DataArray, time_dim: str, sel_opts: Dict[str, Any] = None, + sampling: str = "1H", as_anomaly: bool = True) -> xr.DataArray: + """ + Compute climatological statistics on hourly base either as raw data or anomalies. For each month, an overall + mean value (only used if requiring anomalies) and the mean of each hour are calculated. The climatological + diurnal cycle is positioned on the 16th of each month and interpolated in between by using a distinct + interpolation for each hour of day. The returned array therefore contains data with a yearly cycle (if anomaly + is not calculated) or data without a yearly cycle (if using anomalies). In both cases, the data have an + amplitude that varies over the year. + + :param data: data to apply this method to + :param time_dim: name of temporal axis + :param sel_opts: specific selection options that are applied before calculation of climatological statistics + (default None) + :param sampling: temporal resolution of data (default "1H") + :param as_anomaly: specify whether to use anomalies or raw data including a seasonal cycle of the mean value + (default: True) + :returns: climatological statistics for given data interpolated with given sampling rate + """ - def create_seasonal_hourly_mean(self, data, sel_opts=None, sampling="1H", time_dim="datetime", as_anomaly=True): """Calculate hourly statistics. 
Either the absolute value or the anomaly (as_anomaly=True).""" # can only be used for hourly sampling rate assert sampling == "1H" @@ -263,46 +434,44 @@ class ClimateFIRFilter: data = data.sel(**sel_opts) # create unity xarray in monthly resolution with sampling point in mid of each month - monthly = self.create_unity_array(data, time_dim) * np.nan + monthly = self.create_monthly_unity_array(data, time_dim) * np.nan - seasonal_hourly_means = {} - - for month in data.groupby(f"{time_dim}.month").groups.keys(): - # select each month - single_month_data = data.sel({time_dim: (data[f"{time_dim}.month"] == month)}) - hourly_mean = single_month_data.groupby(f"{time_dim}.hour").mean() - if as_anomaly is True: - hourly_mean = hourly_mean - hourly_mean.mean("hour") - seasonal_hourly_means[month] = hourly_mean + # calculate for each hour in each month a separate mean value + seasonal_hourly_means = self._compute_hourly_mean_per_month(data, time_dim, as_anomaly) + # create seasonal cycles of these hourly averages seasonal_coll = [] for hour in data.groupby(f"{time_dim}.hour").groups.keys(): - h_coll = monthly.__deepcopy__() - for month in seasonal_hourly_means.keys(): - hourly_mean_single_month = seasonal_hourly_means[month].sel(hour=hour, drop=True) - h_coll = xr.where((h_coll[f"{time_dim}.month"] == month), - hourly_mean_single_month, - h_coll) - h_coll = h_coll.resample({time_dim: sampling}).interpolate() - h_coll = h_coll.sel({time_dim: (h_coll[f"{time_dim}.hour"] == hour)}) - seasonal_coll.append(h_coll) - hourly = xr.concat(seasonal_coll, time_dim).sortby(time_dim).resample({time_dim: sampling}).interpolate() + mean_single_hour = self._create_seasonal_cycle_of_single_hour_mean(monthly, seasonal_hourly_means, hour, + time_dim, sampling) + seasonal_coll.append(mean_single_hour) + # combine all cycles in a common data array + hourly = xr.concat(seasonal_coll, time_dim).sortby(time_dim).resample({time_dim: sampling}).interpolate() return hourly @staticmethod - def extend_apriori(data, apriori, time_dim, sampling="1d"): + def extend_apriori(data: xr.DataArray, apriori: xr.DataArray, time_dim: str, sampling: str = "1d", + station_name: str = None) -> xr.DataArray: """ - Extend time range of apriori information. - - This method may not working properly if length of apriori is less then one year. + Extend time range of apriori information to span a longer period as data (or at least of equal length). This + method may not working properly if length of apriori contains data from less then one year. + + :param data: data to get time range of which apriori should span in minimum + :param apriori: data that is adjusted. It is assumed that this data varies in the course of the year but is same + for the same day in different years. Otherwise this method will introduce some unintended artefacts in the + apriori data. + :param time_dim: name of temporal dimension + :param sampling: sampling of data (e.g. 
"1m", "1d", default "1d") + :param station_name: name to use for logging message (default None) + :returns: array which adjusted temporal coverage derived from apriori """ dates = data.coords[time_dim].values td_type = {"1d": "D", "1H": "h"}.get(sampling) # apriori starts after data if dates[0] < apriori.coords[time_dim].values[0]: - logging.debug(f"{data.coords['Stations'].values[0]}: apriori starts after data") + logging.debug(f"{station_name}: apriori starts after data") # add difference in full years date_diff = abs(dates[0] - apriori.coords[time_dim].values[0]).astype("timedelta64[D]") @@ -323,7 +492,7 @@ class ClimateFIRFilter: # apriori ends before data if dates[-1] + np.timedelta64(365, "D") > apriori.coords[time_dim].values[-1]: - logging.debug(f"{data.coords['Stations'].values[0]}: apriori ends before data") + logging.debug(f"{station_name}: apriori ends before data") # add difference in full years + 1 year (because apriori is used as future estimate) date_diff = abs(dates[-1] - apriori.coords[time_dim].values[-1]).astype("timedelta64[D]") @@ -344,24 +513,172 @@ class ClimateFIRFilter: return apriori + def combine_observation_and_apriori(self, data: xr.DataArray, apriori: xr.DataArray, time_dim: str, new_dim: str, + extend_length_history: int, extend_length_future: int, + extend_length_separator: int = 0) -> xr.DataArray: + """ + Combine historical data / observations ("data") and climatological statistics ("apriori"). Historical data are + used on interval [t0 - extend_length_history, t0] and apriori is used on [t0 + 1, t0 + extend_length_future]. If + indicated by the extend_length_seperator, it is possible to shift end of history interval and start of apriori + interval by given number of time steps. + + :param data: historical data for past values, must contain dimensions time_dim and var_dim and might also have + a new_dim dimension + :param apriori: climatological estimate for future values, must contain dimensions time_dim and var_dim, but + can also have dimension new_dim + :param time_dim: name of temporal dimension + :param new_dim: name of new dim on which data is combined along + :param extend_length_history: number of time steps to use from data + :param extend_length_future: number of time steps to use from apriori (minus 1) + :param extend_length_separator: position of last history value to use (default 0), this position indicates the + last value that is used from data (followed by values from apriori). In other words, end of history + interval and start of apriori interval are shifted by this value from t0 (positive or negative). 
+ :returns: combined data array + """ + # check if shift indicated by extend_length_seperator is inside the outer interval limits + # assert (extend_length_separator > -extend_length_history) and (extend_length_separator < extend_length_future) + + # prepare historical data / observation + if new_dim not in data.coords: + history = self._shift_data(data, range(int(-extend_length_history), extend_length_separator + 1), + time_dim, new_dim) + else: + history = data.sel({new_dim: slice(int(-extend_length_history), extend_length_separator)}) + # prepare climatological statistics + if new_dim not in apriori.coords: + future = self._shift_data(apriori, range(extend_length_separator + 1, + extend_length_separator + extend_length_future), + time_dim, new_dim) + else: + future = apriori.sel({new_dim: slice(extend_length_separator + 1, + extend_length_separator + extend_length_future)}) + + # combine historical data [t0-length,t0+sep] and climatological statistics [t0+sep+1,t0+length] + filter_input_data = xr.concat([history.dropna(time_dim), future], dim=new_dim, join="left") + return filter_input_data + + def create_visualization(self, filtered, data, filter_input_data, plot_dates, time_dim, new_dim, sampling, + extend_length_history, extend_length_future, minimum_length, h, + variable_name, extend_length_opts=None): # pragma: no cover + plot_data = [] + extend_length_opts = 0 if extend_length_opts is None else extend_length_opts + for viz_date in set(plot_dates).intersection(filtered.coords[time_dim].values): + try: + td_type = {"1d": "D", "1H": "h"}.get(sampling) + t_minus = viz_date + np.timedelta64(int(-extend_length_history), td_type) + t_plus = viz_date + np.timedelta64(int(extend_length_future + extend_length_opts), td_type) + if new_dim not in data.coords: + tmp_filter_data = self._shift_data(data.sel({time_dim: slice(t_minus, t_plus)}), + range(int(-extend_length_history), + int(extend_length_future + extend_length_opts)), + time_dim, + new_dim).sel({time_dim: viz_date}) + else: + tmp_filter_data = None + valid_range = range(int((len(h) + 1) / 2) if minimum_length is None else minimum_length, + extend_length_opts + 1) + plot_data.append({"t0": viz_date, + "var": variable_name, + "filter_input": filter_input_data.sel({time_dim: viz_date}), + "filter_input_nc": tmp_filter_data, + "valid_range": valid_range, + "time_range": data.sel( + {time_dim: slice(t_minus, t_plus - np.timedelta64(1, td_type))}).coords[ + time_dim].values, + "h": h, + "new_dim": new_dim}) + except: + pass + return plot_data + + @staticmethod + def _get_year_interval(data: xr.DataArray, time_dim: str) -> Tuple[int, int]: + """ + Get year of start and end date of given data. + + :param data: data to extract dates from + :param time_dim: name of temporal axis + :returns: two-element tuple with start and end + """ + start = pd.to_datetime(data.coords[time_dim].min().values).year + end = pd.to_datetime(data.coords[time_dim].max().values).year + return start, end + + @staticmethod + def _calculate_filter_coefficients(window: Union[str, tuple], order: Union[int, tuple], cutoff_high: float, + fs: float) -> np.array: + """ + Calculate filter coefficients for moving window using scipy's signal package for common filter types and local + method firwin_kzf for Kolmogorov Zurbenko filter (kzf). The filter is a low-pass filter. + + :param window: name of the window type which is either a string with the window's name or a tuple containing the + name but also some parameters (e.g. 
`("kaiser", 5)`) + :param order: order of the filter to create as int or parameters m and k of kzf + :param cutoff_high: cutoff frequency to use for low-pass filter in frequency of fs + :param fs: sampling frequency of time series + """ + if window == "kzf": + h = firwin_kzf(*order) + else: + h = signal.firwin(order, cutoff_high, pass_zero="lowpass", fs=fs, window=window) + return h + + @staticmethod + def _trim_data_to_minimum_length(data: xr.DataArray, extend_length_history: int, dim: str, + minimum_length: int = None, extend_length_opts: int = 0) -> xr.DataArray: + """ + Trim data along given axis between either -minimum_length (if given) or -extend_length_history and + extend_length_opts (which is default set to 0). + + :param data: data to trim + :param extend_length_history: start number for trim range (transformed to negative), only used if parameter + minimum_length is not provided + :param dim: dim to apply trim on + :param minimum_length: start number for trim range (transformed to negative), preferably used (default None) + :param extend_length_opts: number to use in "future" + :returns: trimmed data + """ + if minimum_length is None: + return data.sel({dim: slice(-extend_length_history, extend_length_opts)}, drop=True) + else: + return data.sel({dim: slice(-minimum_length, extend_length_opts)}, drop=True) + + @staticmethod + def _create_full_filter_result_array(template_array: xr.DataArray, result_array: xr.DataArray, new_dim: str, + station_name: str = None) -> xr.DataArray: + """ + Create result filter array with same shape line given template data (should be the original input data before + filtering the data). All gaps are filled by nans. + + :param template_array: this array is used as template for shape and ordering of dims + :param result_array: array with data that are filled into template + :param new_dim: new dimension which is shifted/appended to/at the end (if present or not) + :param station_name: string that is attached to logging (default None) + """ + logging.debug(f"{station_name}: create res_full") + new_coords = {**{k: template_array.coords[k].values for k in template_array.coords if k != new_dim}, + new_dim: result_array.coords[new_dim]} + dims = [*template_array.dims, new_dim] if new_dim not in template_array.dims else template_array.dims + result_array = result_array.transpose(*dims) + return result_array.broadcast_like(xr.DataArray(dims=dims, coords=new_coords)) + @TimeTrackingWrapper def clim_filter(self, data, fs, cutoff_high, order, apriori=None, sel_opts=None, sampling="1d", time_dim="datetime", var_dim="variables", window: Union[str, Tuple] = "hamming", - minimum_length=None, new_dim="window", plot_dates=None): + minimum_length=None, new_dim="window", plot_dates=None, station_name=None, + extend_length_opts: Union[dict, int] = None): - logging.debug(f"{data.coords['Stations'].values[0]}: extend apriori") + logging.debug(f"{station_name}: extend apriori") + extend_opts = extend_length_opts if extend_length_opts is not None else {} # calculate apriori information from data if not given and extend its range if not sufficient long enough if apriori is None: - apriori = self.create_monthly_mean(data, sel_opts=sel_opts, sampling=sampling, time_dim=time_dim) + apriori = self.create_monthly_mean(data, time_dim, sel_opts=sel_opts, sampling=sampling) apriori = apriori.astype(data.dtype) - apriori = self.extend_apriori(data, apriori, time_dim, sampling) + apriori = self.extend_apriori(data, apriori, time_dim, sampling, station_name=station_name) # calculate FIR 
filter coefficients - if window == "kzf": - h = firwin_kzf(*order) - else: - h = signal.firwin(order, cutoff_high, pass_zero="lowpass", fs=fs, window=window) + h = self._calculate_filter_coefficients(window, order, cutoff_high, fs) length = len(h) # use filter length if no minimum is given, otherwise use minimum + half filter length for extension @@ -378,30 +695,28 @@ class ClimateFIRFilter: coll = [] for var in reversed(data.coords[var_dim].values): - logging.info(f"{data.coords['Stations'].values[0]} ({var}): sel data") + logging.info(f"{station_name} ({var}): sel data") - _start = pd.to_datetime(data.coords[time_dim].min().values).year - _end = pd.to_datetime(data.coords[time_dim].max().values).year + _start, _end = self._get_year_interval(data, time_dim) + extend_opts_var = extend_opts.get(var, 0) if isinstance(extend_opts, dict) else extend_opts filt_coll = [] for _year in range(_start, _end + 1): - logging.debug(f"{data.coords['Stations'].values[0]} ({var}): year={_year}") + logging.debug(f"{station_name} ({var}): year={_year}") - time_slice = self._create_time_range_extend(_year, sampling, extend_length_history) + # select observations and apriori data + time_slice = self._create_time_range_extend( + _year, sampling, max(extend_length_history, extend_length_future + extend_opts_var)) d = data.sel({var_dim: [var], time_dim: time_slice}) a = apriori.sel({var_dim: [var], time_dim: time_slice}) if len(d.coords[time_dim]) == 0: # no data at all for this year continue # combine historical data / observation [t0-length,t0] and climatological statistics [t0+1,t0+length] - if new_dim not in d.coords: - history = self._shift_data(d, range(int(-extend_length_history), 1), time_dim, var_dim, new_dim) - else: - history = d.sel({new_dim: slice(int(-extend_length_history), 0)}) - if new_dim not in a.coords: - future = self._shift_data(a, range(1, extend_length_future), time_dim, var_dim, new_dim) - else: - future = a.sel({new_dim: slice(1, extend_length_future)}) - filter_input_data = xr.concat([history.dropna(time_dim), future], dim=new_dim, join="left") + filter_input_data = self.combine_observation_and_apriori(d, a, time_dim, new_dim, extend_length_history, + extend_length_future, + extend_length_separator=extend_opts_var) + + # select only data for current year try: filter_input_data = filter_input_data.sel({time_dim: str(_year)}) except KeyError: # no valid data for this year @@ -409,70 +724,45 @@ class ClimateFIRFilter: if len(filter_input_data.coords[time_dim]) == 0: # no valid data for this year continue - logging.debug(f"{data.coords['Stations'].values[0]} ({var}): start filter convolve") - with TimeTracking(name=f"{data.coords['Stations'].values[0]} ({var}): filter convolve", - logging_level=logging.DEBUG): + # apply filter + logging.debug(f"{station_name} ({var}): start filter convolve") + with TimeTracking(name=f"{station_name} ({var}): filter convolve", logging_level=logging.DEBUG): filt = xr.apply_ufunc(fir_filter_convolve, filter_input_data, - input_core_dims=[[new_dim]], - output_core_dims=[[new_dim]], - vectorize=True, - kwargs={"h": h}, - output_dtypes=[d.dtype]) - - if minimum_length is None: - filt_coll.append(filt.sel({new_dim: slice(-extend_length_history, 0)}, drop=True)) - else: - filt_coll.append(filt.sel({new_dim: slice(-minimum_length, 0)}, drop=True)) + input_core_dims=[[new_dim]], output_core_dims=[[new_dim]], + vectorize=True, kwargs={"h": h}, output_dtypes=[d.dtype]) + + # trim data if required + trimmed = self._trim_data_to_minimum_length(filt, 
extend_length_history, new_dim, minimum_length, + extend_length_opts=extend_opts_var) + filt_coll.append(trimmed) # visualization - for viz_date in set(plot_dates).intersection(filt.coords[time_dim].values): - try: - td_type = {"1d": "D", "1H": "h"}.get(sampling) - t_minus = viz_date + np.timedelta64(int(-extend_length_history), td_type) - t_plus = viz_date + np.timedelta64(int(extend_length_future), td_type) - if new_dim not in d.coords: - tmp_filter_data = self._shift_data(d.sel({time_dim: slice(t_minus, t_plus)}), - range(int(-extend_length_history), - int(extend_length_future)), - time_dim, var_dim, new_dim).sel({time_dim: viz_date}) - else: - # tmp_filter_data = d.sel({time_dim: viz_date, - # new_dim: slice(int(-extend_length_history), int(extend_length_future))}) - tmp_filter_data = None - valid_range = range(int((length + 1) / 2) if minimum_length is None else minimum_length, 1) - plot_data.append({"t0": viz_date, - "var": var, - "filter_input": filter_input_data.sel({time_dim: viz_date}), - "filter_input_nc": tmp_filter_data, - "valid_range": valid_range, - "time_range": d.sel( - {time_dim: slice(t_minus, t_plus - np.timedelta64(1, td_type))}).coords[ - time_dim].values, - "h": h, - "new_dim": new_dim}) - except: - pass + plot_data.extend(self.create_visualization(filt, d, filter_input_data, plot_dates, time_dim, new_dim, + sampling, extend_length_history, extend_length_future, + minimum_length, h, var, extend_opts_var)) # collect all filter results coll.append(xr.concat(filt_coll, time_dim)) gc.collect() - logging.debug(f"{data.coords['Stations'].values[0]}: concat all variables") + # concat all variables + logging.debug(f"{station_name}: concat all variables") res = xr.concat(coll, var_dim) - # create result array with same shape like input data, gabs are filled by nans - logging.debug(f"{data.coords['Stations'].values[0]}: create res_full") - - new_coords = {**{k: data.coords[k].values for k in data.coords if k != new_dim}, new_dim: res.coords[new_dim]} - dims = [*data.dims, new_dim] if new_dim not in data.dims else data.dims - res = res.transpose(*dims) - # res_full = xr.DataArray(dims=dims, coords=new_coords) - # res_full.loc[res.coords] = res - # res_full.compute() - res_full = res.broadcast_like(xr.DataArray(dims=dims, coords=new_coords)) + + # create result array with same shape like input data, gaps are filled by nans + res_full = self._create_full_filter_result_array(data, res, new_dim, station_name) return res_full, h, apriori, plot_data @staticmethod - def _create_time_range_extend(year, sampling, extend_length): + def _create_time_range_extend(year: int, sampling: str, extend_length: int) -> slice: + """ + Create a slice object for given year plus extend_length in sampling resolution. + + :param year: year to create time range for + :param sampling: sampling of time range + :param extend_length: number of time steps to extend out of given year + :returns: slice object with time range + """ td_type = {"1d": "D", "1H": "h"}.get(sampling) delta = np.timedelta64(extend_length + 1, td_type) start = np.datetime64(f"{year}-01-01") - delta @@ -480,7 +770,14 @@ class ClimateFIRFilter: return slice(start, end) @staticmethod - def _create_tmp_dimension(data): + def _create_tmp_dimension(data: xr.DataArray) -> str: + """ + Create a tmp dimension with name 'window' preferably. If name is already part of one dimensions, tmp dimension + name is multiplied by itself until not present in dims. Method will raise ValueError after 10 tries. 
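
A tiny standalone sketch of the name search this docstring describes; the helper name is hypothetical and not part of the diff:

    def create_tmp_dimension_name(dims, candidate="window", max_tries=10):
        # double the candidate name until it no longer collides with an existing dimension
        count = 0
        while candidate in dims:
            candidate += candidate          # "window" -> "windowwindow" -> ...
            count += 1
            if count > max_tries:
                raise ValueError("Could not create new dimension.")
        return candidate

    assert create_tmp_dimension_name(("datetime", "variables")) == "window"
    assert create_tmp_dimension_name(("window",)) == "windowwindow"
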
+ + :param data: data array to create a new tmp dimension for with unique name + :returns: valid name for a tmp dimension (preferably 'window') + """ new_dim = "window" count = 0 while new_dim in data.dims: @@ -490,33 +787,41 @@ class ClimateFIRFilter: raise ValueError("Could not create new dimension.") return new_dim - def _shift_data(self, data, index_value, time_dim, squeeze_dim, new_dim): + def _shift_data(self, data: xr.DataArray, index_value: range, time_dim: str, new_dim: str) -> xr.DataArray: + """ + Shift data multiple times to create history or future along dimension new_dim for each time step. + + :param data: data set to shift + :param index_value: range of integers to span history and/or future + :param time_dim: name of temporal dimension that should be shifted + :param new_dim: name of dimension create by data shift + :return: shifted data + """ coll = [] for i in index_value: coll.append(data.shift({time_dim: -i})) - new_ind = self.create_index_array(new_dim, index_value, squeeze_dim) + new_ind = self.create_index_array(new_dim, index_value) return xr.concat(coll, dim=new_ind) @staticmethod - def create_index_array(index_name: str, index_value, squeeze_dim: str): + def create_index_array(index_name: str, index_value: range): + """ + Create index array from a range object to use as index of a data array. + + :param index_name: name of the index dimension + :param index_value: range of values to use as indexes + :returns: index array for given range of values + """ ind = pd.DataFrame({'val': index_value}, index=index_value) - res = xr.Dataset.from_dataframe(ind).to_array(squeeze_dim).rename({'index': index_name}).squeeze( - dim=squeeze_dim, - drop=True) + tmp_dim = index_name + "tmp" + res = xr.Dataset.from_dataframe(ind).to_array(tmp_dim).rename({'index': index_name}) + res = res.squeeze(dim=tmp_dim, drop=True) res.name = index_name return res - @property - def filter_coefficients(self): - return self._h - - @property - def filtered_data(self): - return self._filtered - @property def apriori_data(self): - return self._apriori + return self._apriori_list @property def initial_apriori_data(self): @@ -767,17 +1072,19 @@ class KolmogorovZurbenkoFilterMovingWindow(KolmogorovZurbenkoBaseClass): raise ValueError -def firwin_kzf(m, k): +def firwin_kzf(m: int, k: int) -> np.array: + """Calculate weights of window for Kolmogorov Zurbenko filter.""" + m, k = int(m), int(k) coef = np.ones(m) for i in range(1, k): t = np.zeros((m, m + i * (m - 1))) for km in range(m): t[km, km:km + coef.size] = coef coef = np.sum(t, axis=0) - return coef / m ** k + return coef / (m ** k) -def omega_null_kzf(m, k, alpha=0.5): +def omega_null_kzf(m: int, k: int, alpha: float = 0.5) -> float: a = np.sqrt(6) / np.pi b = 1 / (2 * np.array(k)) c = 1 - alpha ** b @@ -785,5 +1092,6 @@ def omega_null_kzf(m, k, alpha=0.5): return a * np.sqrt(c / d) -def filter_width_kzf(m, k): +def filter_width_kzf(m: int, k: int) -> int: + """Returns window width of the Kolmorogov Zurbenko filter.""" return k * (m - 1) + 1 diff --git a/mlair/helpers/helpers.py b/mlair/helpers/helpers.py index 2f25972cf8490f5dbe0eaebd53f5b530a34d7914..8890ae3c9fe389dfa64bed4750df2944c4cacc61 100644 --- a/mlair/helpers/helpers.py +++ b/mlair/helpers/helpers.py @@ -8,9 +8,6 @@ import json import math import os -import sys - - import numpy as np import pandas as pd import xarray as xr diff --git a/mlair/helpers/join.py b/mlair/helpers/join.py index 93cb0e7b1b34d1ebc13b914ac9626fb4466a7201..67591b29a4e4bcc8b3083869825aed09ebebaf58 100644 --- 
a/mlair/helpers/join.py +++ b/mlair/helpers/join.py @@ -43,6 +43,9 @@ def download_join(station_name: Union[str, List[str]], stat_var: dict, station_t # make sure station_name parameter is a list station_name = helpers.to_list(station_name) + # also ensure that given data_origin dict is no reference + data_origin = None if data_origin is None else {k: v for (k, v) in data_origin.items()} + # get data connection settings join_url_base, headers = join_settings(sampling) diff --git a/mlair/helpers/statistics.py b/mlair/helpers/statistics.py index aa89da0ea66e263d076af9abd578ba125c260bec..33dc05b6612b436665efb3b3960aed07a00faf52 100644 --- a/mlair/helpers/statistics.py +++ b/mlair/helpers/statistics.py @@ -10,6 +10,7 @@ import xarray as xr import pandas as pd from typing import Union, Tuple, Dict, List import itertools +import dask.array as da Data = Union[xr.DataArray, pd.DataFrame] @@ -202,7 +203,12 @@ def log_apply(data: Data, mean: Data, std: Data) -> Data: def mean_squared_error(a, b, dim=None): """Calculate mean squared error.""" - return np.square(a - b).mean(dim) + try: + return da.square(a - b).mean(dim, skipna=True) + except TypeError: + return da.square(a - b).mean(dim) + except Exception: + return np.square(a - b).mean(dim) def mean_absolute_error(a, b, dim=None): @@ -218,6 +224,16 @@ def calculate_error_metrics(a, b, dim): n = (a - b).notnull().sum(dim) return {"mse": mse, "rmse": rmse, "mae": mae, "n": n} +def skill_score_based_on_mse(data: xr.DataArray, obs_name: str, pred_name: str, ref_name: str, + aggregation_dim: str = "index", competitor_dim: str = "type") -> xr.DataArray: + obs = data.sel({competitor_dim: obs_name}) + pred = data.sel({competitor_dim: pred_name}) + ref = data.sel({competitor_dim: ref_name}) + ss = 1 - mean_squared_error(obs, pred, dim=aggregation_dim) / mean_squared_error(obs, ref, dim=aggregation_dim) + return ss + + + class SkillScores: r""" @@ -284,7 +300,7 @@ class SkillScores: def get_model_name_combinations(self): """Return all combinations of two models as tuple and string.""" combinations = list(itertools.combinations(self.models, 2)) - combination_strings = [f"{first}-{second}" for (first, second) in combinations] + combination_strings = [f"{first} - {second}" for (first, second) in combinations] return combinations, combination_strings def skill_scores(self) -> [pd.DataFrame, pd.DataFrame]: @@ -361,7 +377,7 @@ class SkillScores: **kwargs) def general_skill_score(self, data: Data, forecast_name: str, reference_name: str, - observation_name: str = None) -> np.ndarray: + observation_name: str = None, dim: str = "index") -> np.ndarray: r""" Calculate general skill score based on mean squared error. 
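The new skill_score_based_on_mse helper and the extended general_skill_score both compute the same MSE-based skill score, SS = 1 - MSE(obs, pred) / MSE(obs, ref): positive values mean the forecast beats the reference, 1 is a perfect forecast. A short self-contained sketch with synthetic data (variable names are illustrative):

import numpy as np
import xarray as xr

index = np.arange(100)
obs = xr.DataArray(np.random.randn(100), dims="index", coords={"index": index})
pred = obs + 0.1 * np.random.randn(100)                      # forecast close to the observation
ref = xr.DataArray(np.random.randn(100), dims="index",
                   coords={"index": index})                  # e.g. a persistence or OLS reference

def mse(a, b, dim):
    return ((a - b) ** 2).mean(dim)

ss = 1 - mse(obs, pred, "index") / mse(obs, ref, "index")
print(float(ss))  # close to 1 here; <= 0 if the forecast is no better than the reference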
@@ -374,12 +390,12 @@ class SkillScores: """ if observation_name is None: observation_name = self.observation_name - data = data.dropna("index") + data = data.dropna(dim) observation = data.sel(type=observation_name) forecast = data.sel(type=forecast_name) reference = data.sel(type=reference_name) mse = mean_squared_error - skill_score = 1 - mse(observation, forecast) / mse(observation, reference) + skill_score = 1 - mse(observation, forecast, dim=dim) / mse(observation, reference, dim=dim) return skill_score.values def get_count(self, data: Data, forecast_name: str, reference_name: str, diff --git a/mlair/helpers/testing.py b/mlair/helpers/testing.py index abb50883c7af49a0c1571d99f737e310abff9b13..e727d9b50308d339af79f5c5b82b592af6e91921 100644 --- a/mlair/helpers/testing.py +++ b/mlair/helpers/testing.py @@ -1,10 +1,13 @@ """Helper functions that are used to simplify testing.""" import re from typing import Union, Pattern, List +import inspect import numpy as np import xarray as xr +from mlair.helpers.helpers import remove_items, to_list + class PyTestRegex: r""" @@ -86,3 +89,49 @@ def PyTestAllEqual(check_list: List): return self._check_all_equal() return PyTestAllEqualClass(check_list).is_true() + + +def get_all_args(*args, remove=None, add=None): + res = [] + for a in args: + arg_spec = inspect.getfullargspec(a) + res.extend(arg_spec.args) + res.extend(arg_spec.kwonlyargs) + res = sorted(list(set(res))) + if remove is not None: + res = remove_items(res, remove) + if add is not None: + res += to_list(add) + return res + + +def test_nested_equality(obj1, obj2): + + try: + print(f"check type {type(obj1)} and {type(obj2)}") + assert type(obj1) == type(obj2) + + if isinstance(obj1, (tuple, list)): + print(f"check length {len(obj1)} and {len(obj2)}") + assert len(obj1) == len(obj2) + for pos in range(len(obj1)): + print(f"check pos {obj1[pos]} and {obj2[pos]}") + assert test_nested_equality(obj1[pos], obj2[pos]) is True + elif isinstance(obj1, dict): + print(f"check keys {obj1.keys()} and {obj2.keys()}") + assert sorted(obj1.keys()) == sorted(obj2.keys()) + for k in obj1.keys(): + print(f"check pos {obj1[k]} and {obj2[k]}") + assert test_nested_equality(obj1[k], obj2[k]) is True + elif isinstance(obj1, xr.DataArray): + print(f"check xr {obj1} and {obj2}") + assert xr.testing.assert_equal(obj1, obj2) is None + elif isinstance(obj1, np.ndarray): + print(f"check np {obj1} and {obj2}") + assert np.testing.assert_array_equal(obj1, obj2) is None + else: + print(f"check equal {obj1} and {obj2}") + assert obj1 == obj2 + except AssertionError: + return False + return True diff --git a/mlair/helpers/time_tracking.py b/mlair/helpers/time_tracking.py index 3105ebcd04406b7d449ba312bd3af46f83e3a716..5df695b9eee5352152c3189111bacf2fe05a2cb3 100644 --- a/mlair/helpers/time_tracking.py +++ b/mlair/helpers/time_tracking.py @@ -41,7 +41,10 @@ class TimeTrackingWrapper: def __get__(self, instance, cls): """Create bound method object and supply self argument to the decorated method.""" - return types.MethodType(self, instance) + if instance is None: + return self + else: + return types.MethodType(self, instance) class TimeTracking(object): @@ -68,12 +71,13 @@ class TimeTracking(object): The only disadvantage of the latter implementation is, that the duration is logged but not returned. 
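The change to TimeTrackingWrapper.__get__ (returning the wrapper itself when instance is None) is what allows class-level access and introspection, e.g. inspect.getfullargspec as used by get_all_args, on decorated methods. A minimal, MLAir-independent illustration of that descriptor pattern:

import types

class Wrapper:
    def __init__(self, func):
        self.func = func

    def __call__(self, *args, **kwargs):
        return self.func(*args, **kwargs)

    def __get__(self, instance, cls):
        if instance is None:
            return self                               # class access: hand back the wrapper itself
        return types.MethodType(self, instance)       # instance access: bind like a normal method

class A:
    @Wrapper
    def f(self, x):
        return x + 1

print(A().f(1))   # -> 2, bound call works as before
print(A.f)        # -> the Wrapper object (no binding to None), so inspection on the class succeeds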
""" - def __init__(self, start=True, name="undefined job", logging_level=logging.INFO): + def __init__(self, start=True, name="undefined job", logging_level=logging.INFO, log_on_enter=False): """Construct time tracking and start if enabled.""" self.start = None self.end = None self._name = name self._logging = {logging.INFO: logging.info, logging.DEBUG: logging.debug}.get(logging_level, logging.info) + self._log_on_enter = log_on_enter if start: self._start() @@ -124,6 +128,7 @@ class TimeTracking(object): def __enter__(self): """Context manager.""" + self._logging(f"start {self._name}") if self._log_on_enter is True else None return self def __exit__(self, exc_type, exc_val, exc_tb) -> None: diff --git a/mlair/model_modules/keras_extensions.py b/mlair/model_modules/keras_extensions.py index d890e7b0ff3beea812d8fc7766433a84d65a1ebe..8b99acd0f5723d3b00ec1bd0098712753da21b52 100644 --- a/mlair/model_modules/keras_extensions.py +++ b/mlair/model_modules/keras_extensions.py @@ -3,6 +3,7 @@ __author__ = 'Lukas Leufen, Felix Kleinert' __date__ = '2020-01-31' +import copy import logging import math import pickle @@ -199,12 +200,18 @@ class ModelCheckpointAdvanced(ModelCheckpoint): if self.verbose > 0: # pragma: no branch print('\nEpoch %05d: save to %s' % (epoch + 1, file_path)) with open(file_path, "wb") as f: - pickle.dump(callback["callback"], f) + c = copy.copy(callback["callback"]) + if hasattr(c, "model"): + c.model = None + pickle.dump(c, f) else: with open(file_path, "wb") as f: if self.verbose > 0: # pragma: no branch print('\nEpoch %05d: save to %s' % (epoch + 1, file_path)) - pickle.dump(callback["callback"], f) + c = copy.copy(callback["callback"]) + if hasattr(c, "model"): + c.model = None + pickle.dump(c, f) clbk_type = TypedDict("clbk_type", {"name": str, str: Callback, "path": str}) @@ -346,6 +353,8 @@ class CallbackHandler: for pos, callback in enumerate(self.__callbacks): path = callback["path"] clb = pickle.load(open(path, "rb")) + if clb.model is None and hasattr(self._checkpoint, "model"): + clb.model = self._checkpoint.model self._update_callback(pos, clb) def update_checkpoint(self, history_name: str = "hist") -> None: diff --git a/mlair/plotting/abstract_plot_class.py b/mlair/plotting/abstract_plot_class.py index dab45156ac1bbe033ba073e01245ffc8b65ca6b3..7a91c2269ccd03608bcdbe67a634156f55fde91f 100644 --- a/mlair/plotting/abstract_plot_class.py +++ b/mlair/plotting/abstract_plot_class.py @@ -59,7 +59,7 @@ class AbstractPlotClass: if not os.path.exists(plot_folder): os.makedirs(plot_folder) self.plot_folder = plot_folder - self.plot_name = plot_name + self.plot_name = plot_name.replace("/", "_") if plot_name is not None else plot_name self.resolution = resolution if rc_params is None: rc_params = {'axes.labelsize': 'large', @@ -71,6 +71,9 @@ class AbstractPlotClass: self.rc_params = rc_params self._update_rc_params() + def __del__(self): + plt.close('all') + def _plot(self, *args): """Abstract plot class needs to be implemented in inheritance.""" raise NotImplementedError diff --git a/mlair/plotting/data_insight_plotting.py b/mlair/plotting/data_insight_plotting.py index 6180493741c030d5dfdfcfa8972035619632c8aa..1eee96623d4fed6fcfb23fd1438a954a4aca230f 100644 --- a/mlair/plotting/data_insight_plotting.py +++ b/mlair/plotting/data_insight_plotting.py @@ -14,6 +14,7 @@ import numpy as np import pandas as pd import xarray as xr import matplotlib +# matplotlib.use("Agg") from matplotlib import lines as mlines, pyplot as plt, patches as mpatches, dates as mdates from 
astropy.timeseries import LombScargle @@ -495,7 +496,7 @@ class PlotDataHistogram(AbstractPlotClass): # pragma: no cover def _get_inputs_targets(gens, dim): k = list(gens.keys())[0] gen = gens[k][0] - inputs = to_list(gen.get_X(as_numpy=False)[0].coords[dim].values.tolist()) + inputs = list(set([y for x in to_list(gen.get_X(as_numpy=False)) for y in x.coords[dim].values.tolist()])) targets = to_list(gen.get_Y(as_numpy=False).coords[dim].values.tolist()) n_branches = len(gen.get_X(as_numpy=False)) return inputs, targets, n_branches @@ -516,7 +517,7 @@ class PlotDataHistogram(AbstractPlotClass): # pragma: no cover w = min(abs(f(gen).coords[self.window_dim].values)) data = f(gen).sel({self.window_dim: w}) res, _, g_edges = f_proc_hist(data, variables, n_bins, self.variables_dim) - for var in variables: + for var in res.keys(): b = tmp_bins.get(var, []) b.append(res[var]) tmp_bins[var] = b @@ -529,7 +530,7 @@ class PlotDataHistogram(AbstractPlotClass): # pragma: no cover bins = {} edges = {} interval_width = {} - for var in variables: + for var in tmp_bins.keys(): bin_edges = np.linspace(start[var], end[var], n_bins + 1) interval_width[var] = bin_edges[1] - bin_edges[0] for i, e in enumerate(tmp_bins[var]): @@ -632,7 +633,10 @@ class PlotPeriodogram(AbstractPlotClass): # pragma: no cover self._plot_total(raw=True) self._plot_total(raw=False) if multiple > 1: - self._plot_difference(label_names) + self._plot_difference(label_names, plot_name_add="_last") + self._prepare_pgram(generator, pos, multiple, use_multiprocessing=use_multiprocessing, + use_last_input_value=False) + self._plot_difference(label_names, plot_name_add="_first") @staticmethod def _has_filter_dimension(g, pos): @@ -649,7 +653,7 @@ class PlotPeriodogram(AbstractPlotClass): # pragma: no cover return check_data.coords[filter_dim].shape[0], check_data.coords[filter_dim].values.tolist() @TimeTrackingWrapper - def _prepare_pgram(self, generator, pos, multiple=1, use_multiprocessing=False): + def _prepare_pgram(self, generator, pos, multiple=1, use_multiprocessing=False, use_last_input_value=True): """ Create periodogram data. 
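The periodogram preparation above evaluates a Lomb-Scargle periodogram per variable on a fixed, log-spaced frequency grid (self.f_index). A stand-alone sketch of that pattern with synthetic daily data; the frequency range matches the daily case in the code, the signal itself is made up:

import numpy as np
from astropy.timeseries import LombScargle

t = np.arange(0, 365, 1.0)                                        # time axis in days
y = np.sin(2 * np.pi * t / 7) + 0.3 * np.random.randn(t.size)     # weekly cycle plus noise
f_index = np.logspace(-3, 0, 1000)                                # frequencies in 1/day
pgram = LombScargle(t, y, nterms=1, normalization="psd").power(f_index)
print(f_index[np.argmax(pgram)])                                  # peaks near 1/7 per day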
""" @@ -663,7 +667,8 @@ class PlotPeriodogram(AbstractPlotClass): # pragma: no cover plot_data_raw_single = dict() plot_data_mean_single = dict() self.f_index = np.logspace(-3, 0 if self._sampling == "daily" else np.log10(24), 1000) - raw_data_single = self._prepare_pgram_parallel_gen(generator, m, pos, use_multiprocessing) + raw_data_single = self._prepare_pgram_parallel_gen(generator, m, pos, use_multiprocessing, + use_last_input_value=use_last_input_value) for var in raw_data_single.keys(): pgram_com = [] pgram_mean = 0 @@ -705,6 +710,7 @@ class PlotPeriodogram(AbstractPlotClass): # pragma: no cover for i, p in enumerate(output): res.append(p.get()) pool.close() + pool.join() else: # serial solution for var in d[self.variables_dim].values: res.append(f_proc(var, d.loc[{self.variables_dim: var}].squeeze().dropna(self.time_dim))) @@ -715,7 +721,7 @@ class PlotPeriodogram(AbstractPlotClass): # pragma: no cover raw_data_single[var_str] = raw_data_single[var_str] + [(f, pgram)] return raw_data_single - def _prepare_pgram_parallel_gen(self, generator, m, pos, use_multiprocessing): + def _prepare_pgram_parallel_gen(self, generator, m, pos, use_multiprocessing, use_last_input_value=True): """Implementation of data preprocessing using parallel generator element processing.""" raw_data_single = dict() res = [] @@ -723,14 +729,16 @@ class PlotPeriodogram(AbstractPlotClass): # pragma: no cover pool = multiprocessing.Pool( min([psutil.cpu_count(logical=False), len(generator), 16])) # use only physical cpus output = [ - pool.apply_async(f_proc_2, args=(g, m, pos, self.variables_dim, self.time_dim, self.f_index)) + pool.apply_async(f_proc_2, args=(g, m, pos, self.variables_dim, self.time_dim, self.f_index, + use_last_input_value)) for g in generator] for i, p in enumerate(output): res.append(p.get()) pool.close() + pool.join() else: for g in generator: - res.append(f_proc_2(g, m, pos, self.variables_dim, self.time_dim, self.f_index)) + res.append(f_proc_2(g, m, pos, self.variables_dim, self.time_dim, self.f_index, use_last_input_value)) for res_dict in res: for k, v in res_dict.items(): if k not in raw_data_single.keys(): @@ -816,8 +824,8 @@ class PlotPeriodogram(AbstractPlotClass): # pragma: no cover pdf_pages.close() plt.close('all') - def _plot_difference(self, label_names): - plot_name = f"{self.plot_name}_{self._sampling}_{self._add_text}_filter.pdf" + def _plot_difference(self, label_names, plot_name_add = ""): + plot_name = f"{self.plot_name}_{self._sampling}_{self._add_text}_filter{plot_name_add}.pdf" plot_path = os.path.join(os.path.abspath(self.plot_folder), plot_name) logging.info(f"... plotting {plot_name}") pdf_pages = matplotlib.backends.backend_pdf.PdfPages(plot_path) @@ -846,35 +854,59 @@ class PlotPeriodogram(AbstractPlotClass): # pragma: no cover plt.close('all') -def f_proc(var, d_var, f_index, time_dim="datetime"): # pragma: no cover +def f_proc(var, d_var, f_index, time_dim="datetime", use_last_value=True): # pragma: no cover var_str = str(var) t = (d_var[time_dim] - d_var[time_dim][0]).astype("timedelta64[h]").values / np.timedelta64(1, "D") if len(d_var.shape) > 1: # use only max value if dimensions are remaining (e.g. 
max(window) -> latest value) to_remove = remove_items(d_var.coords.dims, time_dim) for e in to_list(to_remove): - d_var = d_var.sel({e: d_var[e].max()}) + d_var = d_var.sel({e: d_var[e].max() if use_last_value is True else d_var[e].min()}) pgram = LombScargle(t, d_var.values.flatten(), nterms=1, normalization="psd").power(f_index) # f, pgram = LombScargle(t, d_var.values.flatten(), nterms=1, normalization="psd").autopower() return var_str, f_index, pgram -def f_proc_2(g, m, pos, variables_dim, time_dim, f_index): # pragma: no cover +def f_proc_2(g, m, pos, variables_dim, time_dim, f_index, use_last_value): # pragma: no cover + + # load lazy data + id_classes = list(filter(lambda x: "id_class" in x, dir(g))) if pos == 0 else ["id_class"] + for id_cls_name in id_classes: + id_cls = getattr(g, id_cls_name) + if hasattr(id_cls, "lazy"): + id_cls.load_lazy() if id_cls.lazy is True else None + raw_data_single = dict() - if hasattr(g.id_class, "lazy"): - g.id_class.load_lazy() if g.id_class.lazy is True else None - if m == 0: - d = g.id_class._data - else: - gd = g.id_class - filter_sel = {"filter": gd.input_data.coords["filter"][m - 1]} - d = (gd.input_data.sel(filter_sel), gd.target_data) - d = d[pos] if isinstance(d, tuple) else d - for var in d[variables_dim].values: - d_var = d.loc[{variables_dim: var}].squeeze().dropna(time_dim) - var_str, f, pgram = f_proc(var, d_var, f_index) - raw_data_single[var_str] = [(f, pgram)] - if hasattr(g.id_class, "lazy"): - g.id_class.clean_up() if g.id_class.lazy is True else None + for dh in list(filter(lambda x: "unfiltered" not in x, id_classes)): + current_cls = getattr(g, dh) + if m == 0: + d = current_cls._data + if d is None: + window_dim = current_cls.window_dim + history = current_cls.history + last_entry = history.coords[window_dim][-1] + d1 = history.sel({window_dim: last_entry}, drop=True) + label = current_cls.label + first_entry = label.coords[window_dim][0] + d2 = label.sel({window_dim: first_entry}, drop=True) + d = (d1, d2) + else: + filter_sel = {"filter": current_cls.input_data.coords["filter"][m - 1]} + d = (current_cls.input_data.sel(filter_sel), current_cls.target_data) + d = d[pos] if isinstance(d, tuple) else d + for var in d[variables_dim].values: + d_var = d.loc[{variables_dim: var}].squeeze().dropna(time_dim) + var_str, f, pgram = f_proc(var, d_var, f_index, use_last_value=use_last_value) + if var_str not in raw_data_single.keys(): + raw_data_single[var_str] = [(f, pgram)] + else: + raise KeyError(f"There are multiple pgrams for key {var_str}. 
Please check your data handler.") + + # perform clean up + for id_cls_name in id_classes: + id_cls = getattr(g, id_cls_name) + if hasattr(id_cls, "lazy"): + id_cls.clean_up() if id_cls.lazy is True else None + return raw_data_single @@ -883,13 +915,14 @@ def f_proc_hist(data, variables, n_bins, variables_dim): # pragma: no cover bin_edges = {} interval_width = {} for var in variables: - d = data.sel({variables_dim: var}).squeeze() if len(data.shape) > 1 else data - res[var], bin_edges[var] = np.histogram(d.values, n_bins) - interval_width[var] = bin_edges[var][1] - bin_edges[var][0] + if var in data.coords[variables_dim]: + d = data.sel({variables_dim: var}).squeeze() if len(data.shape) > 1 else data + res[var], bin_edges[var] = np.histogram(d.values, n_bins) + interval_width[var] = bin_edges[var][1] - bin_edges[var][0] return res, interval_width, bin_edges -class PlotClimateFirFilter(AbstractPlotClass): +class PlotClimateFirFilter(AbstractPlotClass): # pragma: no cover """ Plot climate FIR filter components. @@ -938,7 +971,7 @@ class PlotClimateFirFilter(AbstractPlotClass): """Restructure plot data.""" plot_dict = {} new_dim = None - for i, o in enumerate(range(len(data))): + for i in range(len(data)): plot_data = data[i] for p_d in plot_data: var = p_d.get("var") @@ -1108,3 +1141,131 @@ class PlotClimateFirFilter(AbstractPlotClass): file = os.path.join(self.plot_folder, "plot_data.pickle") with open(file, "wb") as f: dill.dump(data, f) + + +class PlotFirFilter(AbstractPlotClass): # pragma: no cover + """ + Plot FIR filter components. + + * Creates a separate folder FIR inside the given plot directory. + * For each station up to 4 examples are shown (1 for each season). + * Each filtered component and its residuum is drawn in a separate plot. + * A filter component plot includes the FIR input and the filter response + * A filter residuum plot include the FIR residuum + """ + + def __init__(self, plot_folder, plot_data, name): + + logging.info(f"start PlotFirFilter for ({name})") + + # adjust default plot parameters + rc_params = { + 'axes.labelsize': 'large', + 'xtick.labelsize': 'large', + 'ytick.labelsize': 'large', + 'legend.fontsize': 'medium', + 'axes.titlesize': 'large'} + if plot_folder is None: + return + + self.style_dict = { + "original": {"color": "darkgrey", "linestyle": "dashed", "label": "original"}, + "apriori": {"color": "darkgrey", "linestyle": "solid", "label": "estimated future"}, + "clim": {"color": "black", "linestyle": "solid", "label": "clim filter", "linewidth": 2}, + "FIR": {"color": "black", "linestyle": "dashed", "label": "ideal filter", "linewidth": 2}, + "valid_area": {"color": "whitesmoke", "label": "valid area"}, + "t0": {"color": "lightgrey", "lw": 6, "label": "$t_0$"} + } + + plot_folder = os.path.join(os.path.abspath(plot_folder), "FIR") + super().__init__(plot_folder, plot_name=None, rc_params=rc_params) + plot_dict = self._prepare_data(plot_data) + self._name = name + self._plot(plot_dict) + self._store_plot_data(plot_data) + + def _prepare_data(self, data): + """Restructure plot data.""" + plot_dict = {} + for i in range(len(data)): # filter component + for j in range(len(data[i])): # t0 counter + plot_data = data[i][j] + t0 = plot_data.get("t0") + filter_input = plot_data.get("filter_input") + filtered = plot_data.get("filtered") + var_dim = plot_data.get("var_dim") + time_dim = plot_data.get("time_dim") + for var in filtered.coords[var_dim].values: + plot_dict_var = plot_dict.get(var, {}) + plot_dict_t0 = plot_dict_var.get(t0, {}) + plot_dict_order = 
{"filter_input": filter_input.sel({var_dim: var}, drop=True), + "filtered": filtered.sel({var_dim: var}, drop=True), + "time_dim": time_dim} + plot_dict_t0[i] = plot_dict_order + plot_dict_var[t0] = plot_dict_t0 + plot_dict[var] = plot_dict_var + return plot_dict + + def _plot(self, plot_dict): + for var, viz_date_dict in plot_dict.items(): + for it0, t0 in enumerate(viz_date_dict.keys()): + viz_data = viz_date_dict[t0] + try: + for ifilter in sorted(viz_data.keys()): + data = viz_data[ifilter] + filter_input = data["filter_input"] + filtered = data["filtered"] + time_dim = data["time_dim"] + time_axis = filtered.coords[time_dim].values + fig, ax = plt.subplots() + + # plot backgrounds + self._plot_t0(ax, t0) + + # original data + self._plot_data(ax, time_axis, filter_input, style="original") + + # filter response + self._plot_data(ax, time_axis, filtered, style="FIR") + + # set title, legend, and save plot + ax.set_xlim((time_axis[0], time_axis[-1])) + + plt.title(f"Input of Filter ({str(var)})") + plt.legend() + fig.autofmt_xdate() + plt.tight_layout() + self.plot_name = f"FIR_{self._name}_{str(var)}_{it0}_{ifilter}" + self._save() + + # plot residuum + fig, ax = plt.subplots() + self._plot_t0(ax, t0) + self._plot_data(ax, time_axis, filter_input - filtered, style="FIR") + ax.set_xlim((time_axis[0], time_axis[-1])) + plt.title(f"Residuum of Filter ({str(var)})") + plt.legend(loc="upper left") + fig.autofmt_xdate() + plt.tight_layout() + + self.plot_name = f"FIR_{self._name}_{str(var)}_{it0}_{ifilter}_residuum" + self._save() + except Exception as e: + logging.info(f"Could not create plot because of:\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}") + pass + + def _plot_t0(self, ax, t0): + ax.axvline(t0, **self.style_dict["t0"]) + + def _plot_series(self, ax, time_axis, data, style): + ax.plot(time_axis, data, **self.style_dict[style]) + + def _plot_data(self, ax, time_axis, data, style="original"): + # original data + self._plot_series(ax, time_axis, data.values.flatten(), style=style) + + def _store_plot_data(self, data): + """Store plot data. Could be loaded in a notebook to redraw.""" + file = os.path.join(self.plot_folder, "plot_data.pickle") + with open(file, "wb") as f: + dill.dump(data, f) diff --git a/mlair/plotting/postprocessing_plotting.py b/mlair/plotting/postprocessing_plotting.py index 0d81764541eda54697248989aa9e55ae66ff6a5d..1e33d34dfa1d2b437ca6fe1bbf2f38aabb2c202e 100644 --- a/mlair/plotting/postprocessing_plotting.py +++ b/mlair/plotting/postprocessing_plotting.py @@ -7,7 +7,7 @@ import math import os import sys import warnings -from typing import Dict, List, Tuple +from typing import Dict, List, Tuple, Union import matplotlib import matplotlib.pyplot as plt @@ -130,7 +130,7 @@ class PlotMonthlySummary(AbstractPlotClass): # pragma: no cover logging.debug("... 
start plotting") color_palette = [matplotlib.colors.cnames["green"]] + sns.color_palette("Blues_d", self._window_lead_time).as_hex() - ax = sns.boxplot(x='index', y='values', hue='ahead', data=data.compute(), whis=1., palette=color_palette, + ax = sns.boxplot(x='index', y='values', hue='ahead', data=data.compute(), whis=1.5, palette=color_palette, flierprops={'marker': '.', 'markersize': 1}, showmeans=True, meanprops={'markersize': 1, 'markeredgecolor': 'k'}) ylabel = self._spell_out_chemical_concentrations(target_var) @@ -172,16 +172,17 @@ class PlotConditionalQuantiles(AbstractPlotClass): # pragma: no cover warnings.filterwarnings("ignore", message="Attempted to set non-positive bottom ylim on a log-scaled axis.") def __init__(self, stations: List, data_pred_path: str, plot_folder: str = ".", plot_per_seasons=True, - rolling_window: int = 3, model_name: str = "nn", obs_name: str = "obs", target_var_unit: str = "ppb", - **kwargs): + rolling_window: int = 3, forecast_indicator: str = "nn", obs_indicator: str = "obs", + target_var_unit: str = "ppb", **kwargs): + """Initialise.""" super().__init__(plot_folder, "conditional_quantiles") self._data_pred_path = data_pred_path self._stations = stations self.target_var_unit = target_var_unit self._rolling_window = rolling_window - self._model_name = model_name - self._obs_name = obs_name + self._forecast_indicator = forecast_indicator + self._obs_name = obs_indicator self._opts = self._get_opts(kwargs) self._seasons = ['DJF', 'MAM', 'JJA', 'SON'] if plot_per_seasons is True else "" self._data = self._load_data() @@ -207,7 +208,8 @@ class PlotConditionalQuantiles(AbstractPlotClass): # pragma: no cover for station in self._stations: file = os.path.join(self._data_pred_path, f"forecasts_{station}_test.nc") data_tmp = xr.open_dataarray(file) - data_collector.append(data_tmp.loc[:, :, [self._model_name, self._obs_name]].assign_coords(station=station)) + data_collector.append(data_tmp.loc[:, :, [self._forecast_indicator, + self._obs_name]].assign_coords(station=station)) res = xr.concat(data_collector, dim='station').transpose('index', 'type', 'ahead', 'station') return res @@ -315,7 +317,8 @@ class PlotConditionalQuantiles(AbstractPlotClass): # pragma: no cover """Create seasonal plots.""" for season in self._seasons: try: - self._plot_base(data=self._data.where(self._data['index.season'] == season), x_model=self._model_name, + self._plot_base(data=self._data.where(self._data['index.season'] == season), + x_model=self._forecast_indicator, y_model=self._obs_name, plot_name_affix="cali-ref", season=season) except Exception as e: logging.error(f"Could not create plot PlotConditionalQuantiles._plot_seasons: {season}, cali-ref" @@ -323,7 +326,7 @@ class PlotConditionalQuantiles(AbstractPlotClass): # pragma: no cover f"\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}") try: self._plot_base(data=self._data.where(self._data['index.season'] == season), x_model=self._obs_name, - y_model=self._model_name, plot_name_affix="like-base", season=season) + y_model=self._forecast_indicator, plot_name_affix="like-base", season=season) except Exception as e: logging.error(f"Could not create plot PlotConditionalQuantiles._plot_seasons: {season}, like-base" f" due to the following error:" @@ -331,8 +334,8 @@ class PlotConditionalQuantiles(AbstractPlotClass): # pragma: no cover def _plot_all(self): """Plot overall conditional quantiles on full data.""" - self._plot_base(data=self._data, x_model=self._model_name, y_model=self._obs_name, 
plot_name_affix="cali-ref") - self._plot_base(data=self._data, x_model=self._obs_name, y_model=self._model_name, plot_name_affix="like-base") + self._plot_base(data=self._data, x_model=self._forecast_indicator, y_model=self._obs_name, plot_name_affix="cali-ref") + self._plot_base(data=self._data, x_model=self._obs_name, y_model=self._forecast_indicator, plot_name_affix="like-base") @TimeTrackingWrapper def _plot_base(self, data: xr.DataArray, x_model: str, y_model: str, plot_name_affix: str, season: str = ""): @@ -413,14 +416,14 @@ class PlotClimatologicalSkillScore(AbstractPlotClass): # pragma: no cover :param plot_folder: path to save the plot (default: current directory) :param score_only: if true plot only scores of CASE I to IV, otherwise plot all single terms (default True) :param extra_name_tag: additional tag that can be included in the plot name (default "") - :param model_setup: architecture type to specify plot name (default "") + :param model_name: architecture type to specify plot name (default "") """ def __init__(self, data: Dict, plot_folder: str = ".", score_only: bool = True, extra_name_tag: str = "", - model_setup: str = ""): + model_name: str = ""): """Initialise.""" - super().__init__(plot_folder, f"skill_score_clim_{extra_name_tag}{model_setup}") + super().__init__(plot_folder, f"skill_score_clim_{extra_name_tag}{model_name}") self._labels = None self._data = self._prepare_data(data, score_only) self._plot(score_only) @@ -461,7 +464,7 @@ class PlotClimatologicalSkillScore(AbstractPlotClass): # pragma: no cover fig, ax = plt.subplots() if not score_only: fig.set_size_inches(11.7, 8.27) - sns.boxplot(x="terms", y="data", hue="ahead", data=self._data, ax=ax, whis=1., palette="Blues_d", + sns.boxplot(x="terms", y="data", hue="ahead", data=self._data, ax=ax, whis=1.5, palette="Blues_d", showmeans=True, meanprops={"markersize": 1, "markeredgecolor": "k"}, flierprops={"marker": "."}) ax.axhline(y=0, color="grey", linewidth=.5) ax.set(ylabel=f"{self._label_add(score_only)}skill score", xlabel="", title="summary of all stations", @@ -508,20 +511,13 @@ class PlotCompetitiveSkillScore(AbstractPlotClass): # pragma: no cover #<<<<<<< HEAD def __init__(self, data: pd.DataFrame, plot_folder=".", model_setup="NN", sampling="daily", model_name_for_plots=None): -#======= -# def __init__(self, data: Dict[str, pd.DataFrame], plot_folder=".", model_setup="NN"): -#>>>>>>> develop """Initialise.""" super().__init__(plot_folder, f"skill_score_competitive_{model_setup}") self._model_setup = model_setup self._sampling = self._get_sampling(sampling) self._labels = None -#<<<<<<< HEAD self._model_name_for_plots = model_name_for_plots self._data = self._prepare_data(data) -#======= -# self._data = self._prepare_data(helpers.remove_items(data, "total")) -#>>>>>>> develop default_plot_name = self.plot_name # draw full detail plot self.plot_name = default_plot_name + "_full_detail" @@ -565,7 +561,7 @@ class PlotCompetitiveSkillScore(AbstractPlotClass): # pragma: no cover if self._model_name_for_plots is not None: data.loc[:, 'comparison'] = [i.replace('nn-', f'{self._model_name_for_plots}-') for i in data['comparison']] order = self._create_pseudo_order(data) - sns.boxplot(x="comparison", y="data", hue="ahead", data=data, whis=1., ax=ax, palette="Blues_d", + sns.boxplot(x="comparison", y="data", hue="ahead", data=data, whis=1.5, ax=ax, palette="Blues_d", showmeans=True, meanprops={"markersize": 3, "markeredgecolor": "k"}, flierprops={"marker": "."}, order=order) ax.axhline(y=0, color="grey", 
linewidth=.5) @@ -582,7 +578,7 @@ class PlotCompetitiveSkillScore(AbstractPlotClass): # pragma: no cover if self._model_name_for_plots is not None: data.loc[:, 'comparison'] = [i.replace('nn-', f'{self._model_name_for_plots}-') for i in data['comparison']] order = self._create_pseudo_order(data) - sns.boxplot(y="comparison", x="data", hue="ahead", data=data, whis=1., ax=ax, palette="Blues_d", + sns.boxplot(y="comparison", x="data", hue="ahead", data=data, whis=1.5, ax=ax, palette="Blues_d", showmeans=True, meanprops={"markersize": 3, "markeredgecolor": "k"}, flierprops={"marker": "."}, order=order) ax.axvline(x=0, color="grey", linewidth=.5) @@ -593,14 +589,13 @@ class PlotCompetitiveSkillScore(AbstractPlotClass): # pragma: no cover def _create_pseudo_order(self, data): """Provide first predefined elements and append all remaining.""" - first_elements = [f"{self._model_setup}-persi", "ols-persi", f"{self._model_setup}-ols"] + first_elements = [f"{self._model_setup} - persi", "ols - persi", f"{self._model_setup} - ols"] first_elements = list(filter(lambda x: x in data.comparison.tolist(), first_elements)) uniq, index = np.unique(first_elements + data.comparison.unique().tolist(), return_index=True) return uniq[index.argsort()] def _filter_comparisons(self, data): - # filtered_headers = list(filter(lambda x: "nn-" in x, data.comparison.unique())) - filtered_headers = list(filter(lambda x: f"{self._model_name_for_plots}-" in x, data.comparison.unique())) + filtered_headers = list(filter(lambda x: f"{self._model_setup} - " in x, data.comparison.unique())) return data[data.comparison.isin(filtered_headers)] @staticmethod @@ -619,15 +614,76 @@ class PlotCompetitiveSkillScore(AbstractPlotClass): # pragma: no cover return lower, upper +class PlotSectorialSkillScore(AbstractPlotClass): # pragma: no cover + + def __init__(self, data: xr.DataArray, plot_folder: str = ".", model_setup: str = "NN", sampling: str = "daily", + model_name_for_plots: Union[str, None] = None, ahead_dim: str = "ahead"): + """Initialise.""" + super().__init__(plot_folder, f"skill_score_sectorial_{model_setup}") + self._model_setup = model_setup + self._sampling = self._get_sampling(sampling) + self._ahead_dim = ahead_dim + self._labels = None + self._model_name_for_plots = model_name_for_plots + self._data = self._prepare_data(data) + self._plot() + self._save() + self.plot_name = self.plot_name + "_vertical" + self._plot_vertical() + self._save() + + def _prepare_data(self, data: xr.DataArray): + self._labels = [str(i) + self._sampling for i in data.coords[self._ahead_dim].values] + data = data.to_dataframe("data")[["data"]].stack(level=0).reset_index(level=3, drop=True).reset_index(name="data") + return data + + def _plot(self): + size = max([len(np.unique(self._data.sector)), 6]) + fig, ax = plt.subplots(figsize=(size, size * 0.8)) + data = self._data + sns.boxplot(x="sector", y="data", hue="ahead", data=data, whis=1.5, ax=ax, palette="Blues_d", + showmeans=True, meanprops={"markersize": 3, "markeredgecolor": "k"}, flierprops={"marker": "."}, + ) + ax.axhline(y=0, color="grey", linewidth=.5) + ax.set(ylabel="skill score", xlabel="sector", title="summary of all stations") + handles, _ = ax.get_legend_handles_labels() + plt.xticks(rotation=45, horizontalalignment="right") + ax.legend(handles, self._labels) + plt.tight_layout() + + def _plot_vertical(self): + """Plot skill scores of the comparisons, but vertically aligned.""" + fig, ax = plt.subplots() + data = self._data + sns.boxplot(y="sector", x="data", hue="ahead", 
data=data, whis=1.5, ax=ax, palette="Blues_d", + showmeans=True, meanprops={"markersize": 3, "markeredgecolor": "k"}, flierprops={"marker": "."}, + ) + ax.axvline(x=0, color="grey", linewidth=.5) + ax.set(xlabel="skill score", ylabel="sector", title="summary of all stations") + handles, _ = ax.get_legend_handles_labels() + ax.legend(handles, self._labels) + plt.tight_layout() + + @staticmethod + def _lim(data) -> Tuple[float, float]: + """ + Calculate axis limits from data (Can be used to set axis extend). + + Lower limit is the minimum of 0 and data's minimum (reduced by small subtrahend) and upper limit is data's + maximum (increased by a small addend). + + :return: + """ + limit = 5 + lower = np.max([-limit, np.min([0, helpers.float_round(data["data"].min(), 2) - 0.1])]) + upper = np.min([limit, helpers.float_round(data.max()[2], 2) + 0.1]) + return lower, upper + + @TimeTrackingWrapper -class PlotBootstrapSkillScore(AbstractPlotClass): # pragma: no cover +class PlotFeatureImportanceSkillScore(AbstractPlotClass): # pragma: no cover """ - Create plot of climatological skill score after Murphy (1988) as box plot over all stations. - - A forecast time step (called "ahead") is separately shown to highlight the differences for each prediction time - step. Either each single term is plotted (score_only=False) or only the resulting scores CASE I to IV are displayed - (score_only=True, default). Y-axis is adjusted following the data and not hard coded. The plot is saved under - plot_folder path with name skill_score_clim_{extra_name_tag}{model_setup}.pdf and resolution of 500dpi. + Create plot of feature importance analysis. By passing a list `separate_vars` containing variable names, a second plot is created showing the `separate_vars` and the remaining variables side by side with different scaling. @@ -640,23 +696,23 @@ class PlotBootstrapSkillScore(AbstractPlotClass): # pragma: no cover """ - def __init__(self, data: Dict, plot_folder: str = ".", model_setup: str = "", separate_vars: List = None, - sampling: str = "daily", ahead_dim: str = "ahead", bootstrap_type: str = None, - bootstrap_method: str = None): + def __init__(self, data: Dict, plot_folder: str = ".", separate_vars: List = None, sampling: str = "daily", + ahead_dim: str = "ahead", bootstrap_type: str = None, bootstrap_method: str = None, + boot_dim: str = "boots", model_name: str = "NN", branch_names: list = None, ylim: tuple = None): """ Set attributes and create plot. :param data: dictionary with station names as keys and 2D xarrays as values, consist on axis ahead and terms. 
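PlotSectorialSkillScore and the other skill-score plots share one pattern: flatten an xarray DataArray into a long-format pandas frame and pass it to seaborn's boxplot with hue set to the forecast lead time. A minimal stand-alone version of that pattern (dimension names and values are illustrative):

import numpy as np
import xarray as xr
import seaborn as sns
import matplotlib.pyplot as plt

skill = xr.DataArray(0.2 * np.random.randn(5, 3, 4),
                     dims=("station", "ahead", "sector"),
                     coords={"station": [f"S{i}" for i in range(5)],
                             "ahead": [1, 2, 3],
                             "sector": ["N", "E", "S", "W"]})
df = skill.to_dataframe("data").reset_index()     # one row per (station, ahead, sector)
ax = sns.boxplot(x="sector", y="data", hue="ahead", data=df, whis=1.5, palette="Blues_d",
                 showmeans=True, meanprops={"markersize": 3, "markeredgecolor": "k"})
ax.axhline(y=0, color="grey", linewidth=.5)
plt.tight_layout()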
:param plot_folder: path to save the plot (default: current directory) - :param model_setup: architecture type to specify plot name (default "CNN") :param separate_vars: variables to plot separated (default: ['o3']) :param sampling: type of sampling rate, should be either hourly or daily (default: "daily") :param ahead_dim: name of the ahead dimensions (default: "ahead") :param bootstrap_annotation: additional information to use in the file name (default: None) + :param model_name: architecture type to specify plot name (default "NN") """ annotation = ["_".join([s for s in ["", bootstrap_type, bootstrap_method] if s is not None])][0] - super().__init__(plot_folder, f"skill_score_bootstrap_{model_setup}{annotation}") + super().__init__(plot_folder, f"feature_importance_{model_name}{annotation}") if separate_vars is None: separate_vars = ['o3'] self.sampling = sampling @@ -664,7 +720,6 @@ class PlotBootstrapSkillScore(AbstractPlotClass): # pragma: no cover self._x_name = "boot_var" # <<<<<<< HEAD # self._data = self._prepare_data(data) - self._individual_vars = set(self._data[self._x_name].values) # self._plot() # self._save(bbox_inches='tight') # self.plot_name += '_separated' @@ -672,30 +727,66 @@ class PlotBootstrapSkillScore(AbstractPlotClass): # pragma: no cover # self._save(bbox_inches='tight') # ======= self._ahead_dim = ahead_dim + self._boot_dim = boot_dim self._boot_type = self._set_bootstrap_type(bootstrap_type) self._boot_method = self._set_bootstrap_method(bootstrap_method) + self._number_of_bootstraps = 0 + self._branches_names = branch_names + self._ylim = ylim - self._title = f"Bootstrap analysis ({self._boot_method}, {self._boot_type})" self._data = self._prepare_data(data, sampling) + self._set_title(model_name) + self._individual_vars = set(self._data[self._x_name].values) if "branch" in self._data.columns: plot_name = self.plot_name for branch in self._data["branch"].unique(): - self._title = f"Bootstrap analysis ({self._boot_method}, {self._boot_type}, {branch})" - self._plot(branch=branch) + self._set_title(model_name, branch) self.plot_name = f"{plot_name}_{branch}" - self._save() + try: + self._plot(branch=branch) + self._save() + except ValueError as e: + logging.info(f"Did not plot {self.plot_name} because of {e}") + if len(set(separate_vars).intersection(self._data[self._x_name].unique())) > 0: + self.plot_name += '_separated' + try: + self._plot(branch=branch, separate_vars=separate_vars) + self._save(bbox_inches='tight') + except ValueError as e: + logging.info(f"Did not plot {self.plot_name} because of {e}") else: - self._plot() - self._save() + try: + self._plot() + self._save() + except ValueError as e: + logging.info(f"Did not plot {self.plot_name} because of {e}") if len(set(separate_vars).intersection(self._data[self._x_name].unique())) > 0: self.plot_name += '_separated' - self._plot(separate_vars=separate_vars) - self._save(bbox_inches='tight') + try: + self._plot(separate_vars=separate_vars) + self._save(bbox_inches='tight') + except ValueError as e: + logging.info(f"Did not plot {self.plot_name} because of {e}") @staticmethod def _set_bootstrap_type(boot_type): return {"singleinput": "single input"}.get(boot_type, boot_type) + def _set_title(self, model_name, branch=None): + title_d = {"single input": "Single Inputs", "branch": "Input Branches", "variable": "Variables"} + base_title = f"{model_name}\nImportance of {title_d[self._boot_type]}" + + additional = [] + if branch is not None: + branch_name = self._branches_names[branch] if self._branches_names 
is not None else branch + additional.append(branch_name) + if self._number_of_bootstraps > 1: + additional.append(f"n={self._number_of_bootstraps}") + additional_title = ", ".join(additional) + if len(additional_title) > 0: + additional_title = f" ({additional_title})" + self._title = base_title + additional_title + @staticmethod def _set_bootstrap_method(boot_method): return {"zero_mean": "zero mean", "shuffle": "shuffled"}.get(boot_method, boot_method) @@ -723,14 +814,31 @@ class PlotBootstrapSkillScore(AbstractPlotClass): # pragma: no cover # ======= station_dim = "station" data = helpers.dict_to_xarray(data, station_dim).sortby(self._x_name) + data = data.transpose(station_dim, self._ahead_dim, self._boot_dim, self._x_name) if self._boot_type == "single input": number_tags = self._get_number_tag(data.coords[self._x_name].values, split_by='_') new_boot_coords = self._return_vars_without_number_tag(data.coords[self._x_name].values, split_by='_', keep=1, as_unique=True) - values = data.values.reshape((data.shape[0], len(new_boot_coords), len(number_tags), data.shape[-1])) - data = xr.DataArray(values, coords={station_dim: data.coords["station"], self._x_name: new_boot_coords, - "branch": number_tags, self._ahead_dim: data.coords[self._ahead_dim]}, - dims=[station_dim, self._x_name, "branch", self._ahead_dim]) + try: + values = data.values.reshape((*data.shape[:3], len(number_tags), len(new_boot_coords))) + data = xr.DataArray(values, coords={station_dim: data.coords[station_dim], self._x_name: new_boot_coords, + "branch": number_tags, self._ahead_dim: data.coords[self._ahead_dim], + self._boot_dim: data.coords[self._boot_dim]}, + dims=[station_dim, self._ahead_dim, self._boot_dim, "branch", self._x_name]) + except ValueError: + data_coll = [] + for nr in number_tags: + filtered_coords = list(filter(lambda x: nr in x.split("_")[0], data.coords[self._x_name].values)) + new_boot_coords = self._return_vars_without_number_tag(filtered_coords, split_by='_', keep=1, + as_unique=True) + sel_data = data.sel({self._x_name: filtered_coords}) + values = sel_data.values.reshape((*data.shape[:3], 1, len(new_boot_coords))) + sel_data = xr.DataArray(values, coords={station_dim: data.coords[station_dim], self._x_name: new_boot_coords, + "branch": [nr], self._ahead_dim: data.coords[self._ahead_dim], + self._boot_dim: data.coords[self._boot_dim]}, + dims=[station_dim, self._ahead_dim, self._boot_dim, "branch", self._x_name]) + data_coll.append(sel_data) + data = xr.concat(data_coll, "branch") else: try: new_boot_coords = self._return_vars_without_number_tag(data.coords[self._x_name].values, split_by='_', @@ -742,7 +850,8 @@ class PlotBootstrapSkillScore(AbstractPlotClass): # pragma: no cover self._labels = [str(i) + sampling_letter for i in data.coords[self._ahead_dim].values] if station_dim not in data.dims: data = data.expand_dims(station_dim) - return data.to_dataframe("data").reset_index(level=np.arange(len(data.dims)).tolist()) + self._number_of_bootstraps = np.unique(data.coords[self._boot_dim].values).shape[0] + return data.to_dataframe("data").reset_index(level=np.arange(len(data.dims)).tolist()).dropna() @staticmethod def _get_target_sampling(sampling, pos): @@ -791,16 +900,16 @@ class PlotBootstrapSkillScore(AbstractPlotClass): # pragma: no cover if separate_vars is None: self._plot_all_variables(branch) else: - self._plot_selected_variables(separate_vars) + self._plot_selected_variables(separate_vars, branch) - def _plot_selected_variables(self, separate_vars: List): - data = self._data - 
self.raise_error_if_separate_vars_do_not_exist(data, separate_vars, self._x_name) + def _plot_selected_variables(self, separate_vars: List, branch=None): + data = self._data if branch is None else self._data[self._data["branch"] == str(branch)] + self.raise_error_if_vars_do_not_exist(data, separate_vars, self._x_name, name="separate_vars") all_variables = self._get_unique_values_from_column_of_df(data, self._x_name) remaining_vars = helpers.remove_items(all_variables, separate_vars) -# <<<<<<< HEAD # data_first = self._select_data(df=data, variables=separate_vars, column_name='boot_var') # data_second = self._select_data(df=data, variables=remaining_vars, column_name='boot_var') + self.raise_error_if_vars_do_not_exist(data, remaining_vars, self._x_name, name="remaining_vars") order_first = self.set_order_for_x_axis(separate_vars) order_second, center_names_second = self.set_order_for_x_axis(remaining_vars, return_center_names=True) number_of_vars_second = len(order_second) @@ -810,8 +919,6 @@ class PlotBootstrapSkillScore(AbstractPlotClass): # pragma: no cover figsize = (len(self._individual_vars) / 2, 10) else: figsize = (15, 10) - -# ======= data_first = self._select_data(df=data, variables=separate_vars, column_name=self._x_name) data_second = self._select_data(df=data, variables=remaining_vars, column_name=self._x_name) @@ -819,37 +926,31 @@ class PlotBootstrapSkillScore(AbstractPlotClass): # pragma: no cover figsize=figsize, gridspec_kw={'width_ratios': [len(separate_vars), len(remaining_vars)]}) -# >>>>>>> develop if len(separate_vars) > 1: first_box_width = .8 else: - first_box_width = 2. + first_box_width = .8 -# <<<<<<< HEAD -# sns.boxplot(x=self._x_name, y="data", hue="ahead", data=data_first, ax=ax[0], whis=1., palette="Blues_d", -# showmeans=True, order=order_first, meanprops={"markersize": 1, "markeredgecolor": "k"}, -# flierprops={"marker": "."}, width=first_box_width -# ) -# ax[0].set(ylabel=f"skill score", xlabel="") -# -# sns.boxplot(x=self._x_name, y="data", hue="ahead", data=data_second, ax=ax[1], whis=1., palette="Blues_d", -# showmeans=True, order=order_second, meanprops={"markersize": 1, "markeredgecolor": "k"}, -# flierprops={"marker": "."}, -# ) -# ======= sns.boxplot(x=self._x_name, y="data", hue=self._ahead_dim, data=data_first, ax=ax[0], whis=1., palette="Blues_d", showmeans=True, meanprops={"markersize": 1, "markeredgecolor": "k"}, - flierprops={"marker": "."}, width=first_box_width) + showfliers=False, width=first_box_width) ax[0].set(ylabel=f"skill score", xlabel="") + if self._ylim is not None: + _ylim = self._ylim if isinstance(self._ylim, tuple) else self._ylim[0] + ax[0].set(ylim=_ylim) - sns.boxplot(x=self._x_name, y="data", hue=self._ahead_dim, data=data_second, ax=ax[1], whis=1., + sns.boxplot(x=self._x_name, y="data", hue=self._ahead_dim, data=data_second, ax=ax[1], whis=1.5, palette="Blues_d", showmeans=True, meanprops={"markersize": 1, "markeredgecolor": "k"}, - flierprops={"marker": "."}) -# >>>>>>> develop + showfliers=False, flierprops={"marker": "."}) + ax[1].set(ylabel="", xlabel="") if group_size > 1: [ax[1].axvline(x + .5, color='grey') for i, x in enumerate(ax[1].get_xticks(), start=1) if i % group_size == 0] ax[1].yaxis.tick_right() + if self._ylim is not None and isinstance(self._ylim, list): + _ylim = self._ylim[1] + ax[1].set(ylim=_ylim) + handles, _ = ax[1].get_legend_handles_labels() for sax in ax: matplotlib.pyplot.sca(sax) @@ -857,7 +958,8 @@ class PlotBootstrapSkillScore(AbstractPlotClass): # pragma: no cover plt.xticks(rotation=45, 
ha='right') sax.legend_.remove() - fig.legend(handles, self._labels, loc='upper center', ncol=len(handles) + 1, ) + # fig.legend(handles, self._labels, loc='upper center', ncol=len(handles) + 1, ) + ax[1].legend(handles, self._labels, loc='lower center', ncol=len(handles) + 1, fontsize="medium") def align_yaxis(ax1, ax2): """ @@ -882,6 +984,7 @@ class PlotBootstrapSkillScore(AbstractPlotClass): # pragma: no cover align_yaxis(ax[0], ax[1]) align_yaxis(ax[0], ax[1]) + plt.subplots_adjust(right=0.8) plt.title(self._title) @staticmethod @@ -895,9 +998,13 @@ class PlotBootstrapSkillScore(AbstractPlotClass): # pragma: no cover selected_data = pd.concat([selected_data, tmp_var], axis=0) return selected_data - def raise_error_if_separate_vars_do_not_exist(self, data, separate_vars, column_name): - if not self._variables_exist_in_df(df=data, variables=separate_vars, column_name=column_name): - raise ValueError(f"At least one entry of `separate_vars' does not exist in `self.data' ") + def raise_error_if_vars_do_not_exist(self, data, vars, column_name, name="separate_vars"): + if len(vars) == 0: + msg = f"No variables are given for `{name}' to check in `self.data' " + raise ValueError(msg) + if not self._variables_exist_in_df(df=data, variables=vars, column_name=column_name): + msg = f"At least one entry of `{name}' does not exist in `self.data' " + raise ValueError(msg) @staticmethod def _get_unique_values_from_column_of_df(df: pd.DataFrame, column_name: str) -> List: @@ -911,33 +1018,6 @@ class PlotBootstrapSkillScore(AbstractPlotClass): # pragma: no cover """ """ -# # <<<<<<< HEAD -# number_of_vars = len(self._individual_vars) -# order, center_names = self.set_order_for_x_axis(self._individual_vars, return_center_names=True) -# group_size = int(number_of_vars / len(center_names)) -# -# if number_of_vars > 20: -# fig, ax = plt.subplots(figsize=(number_of_vars/2, 10)) -# else: -# fig, ax = plt.subplots(figsize=(15, 10)) -# sns.boxplot(x=self._x_name, y="data", hue="ahead", data=self._data, ax=ax, whis=1., palette="Blues_d", -# showmeans=True, meanprops={"markersize": 1, "markeredgecolor": "k"}, flierprops={"marker": "."}, -# order=order) -# ax.axhline(y=0, color="grey", linewidth=.5) -# if group_size > 1: -# [ax.axvline(x + .5, color='grey') for i, x in enumerate(ax.get_xticks(), start=1) if i % group_size == 0] -# plt.xticks(rotation=45, horizontalalignment="right") -# ax.set(ylabel=f"skill score", xlabel="", title="summary of all stations") -# # ======= -# fig, ax = plt.subplots() -# plot_data = self._data if branch is None else self._data[self._data["branch"] == str(branch)] -# sns.boxplot(x=self._x_name, y="data", hue=self._ahead_dim, data=plot_data, ax=ax, whis=1., palette="Blues_d", -# showmeans=True, meanprops={"markersize": 1, "markeredgecolor": "k"}, flierprops={"marker": "."}) -# ax.axhline(y=0, color="grey", linewidth=.5) -# plt.xticks(rotation=45) -# ax.set(ylabel=f"skill score", xlabel="", title=self._title) -# >>>>>>> develop - # ToDo B: idea to solve merge conflict number_of_vars = len(self._individual_vars) order, center_names = self.set_order_for_x_axis(self._individual_vars, return_center_names=True) @@ -949,15 +1029,37 @@ class PlotBootstrapSkillScore(AbstractPlotClass): # pragma: no cover fig, ax = plt.subplots(figsize=(15, 10)) plot_data = self._data if branch is None else self._data[self._data["branch"] == str(branch)] - sns.boxplot(x=self._x_name, y="data", hue=self._ahead_dim, data=plot_data, ax=ax, whis=1., palette="Blues_d", - showmeans=True, meanprops={"markersize": 1, 
"markeredgecolor": "k"}, flierprops={"marker": "."}) + if self._boot_type == "branch": + fig, ax = plt.subplots(figsize=(0.5 + 2 / len(plot_data[self._x_name].unique()) + len(plot_data[self._x_name].unique()),4)) + sns.boxplot(x=self._x_name, y="data", hue=self._ahead_dim, data=plot_data, ax=ax, whis=1., + palette="Blues_d", showmeans=True, meanprops={"markersize": 1, "markeredgecolor": "k"}, + showfliers=False, width=0.8) + else: + fig, ax = plt.subplots() + sns.boxplot(x=self._x_name, y="data", hue=self._ahead_dim, data=plot_data, ax=ax, whis=1.5, palette="Blues_d", + showmeans=True, meanprops={"markersize": 1, "markeredgecolor": "k"}, showfliers=False) ax.axhline(y=0, color="grey", linewidth=.5) +# <<<<<<< HEAD if group_size > 1: [ax.axvline(x + .5, color='grey') for i, x in enumerate(ax.get_xticks(), start=1) if i % group_size == 0] plt.xticks(rotation=45, horizontalalignment="right") ax.set(ylabel=f"skill score", xlabel="", title="summary of all stations") - # ToDo E: idea to solve merge conflict +# ======= + + if self._ylim is not None: + if isinstance(self._ylim, tuple): + _ylim = self._ylim + else: + _ylim = (min(self._ylim[0][0], self._ylim[1][0]), max(self._ylim[0][1], self._ylim[1][1])) + ax.set(ylim=_ylim) + + if self._boot_type == "branch": + plt.xticks() + else: + plt.xticks(rotation=45) + ax.set(ylabel=f"skill score", xlabel="", title=self._title) +# >>>>>>> origin/develop handles, _ = ax.get_legend_handles_labels() ax.legend(handles, self._labels) plt.tight_layout() @@ -1141,7 +1243,6 @@ class PlotSeparationOfScales(AbstractPlotClass): # pragma: no cover data = dh.get_X(as_numpy=False)[0] station = dh.id_class.station[0] data = data.sel(Stations=station) - # plt.subplots() data.plot(x=self.time_dim, y=self.window_dim, col=self.filter_dim, row=self.target_dim, robust=True) self.plot_name = f"{orig_plot_name}_{station}" self._save() @@ -1152,7 +1253,7 @@ class PlotSampleUncertaintyFromBootstrap(AbstractPlotClass): # pragma: no cover def __init__(self, data: xr.DataArray, plot_folder: str = ".", model_type_dim: str = "type", error_measure: str = "mse", error_unit: str = None, dim_name_boots: str = 'boots', - block_length: str = None): + block_length: str = None, model_name: str = "NN", model_indicator: str = "nn"): super().__init__(plot_folder, "sample_uncertainty_from_bootstrap") default_name = self.plot_name self.model_type_dim = model_type_dim @@ -1160,6 +1261,7 @@ class PlotSampleUncertaintyFromBootstrap(AbstractPlotClass): # pragma: no cover self.dim_name_boots = dim_name_boots self.error_unit = error_unit self.block_length = block_length + data = self.rename_model_indicator(data, model_name, model_indicator) self.prepare_data(data) self._plot(orientation="v") @@ -1177,10 +1279,14 @@ class PlotSampleUncertaintyFromBootstrap(AbstractPlotClass): # pragma: no cover self._data_table = None self._n_boots = None + def rename_model_indicator(self, data, model_name, model_indicator): + data.coords[self.model_type_dim] = [{model_indicator: model_name}.get(n, n) + for n in data.coords[self.model_type_dim].values] + return data + def prepare_data(self, data: xr.DataArray): - self._data_table = data.to_pandas() - if "persi" in self._data_table.columns: - self._data_table["persi"] = self._data_table.pop("persi") + data_table = data.to_pandas() + self._data_table = data_table[data_table.mean().sort_values().index] self._n_boots = self._data_table.shape[0] def _apply_root(self): @@ -1195,11 +1301,11 @@ class PlotSampleUncertaintyFromBootstrap(AbstractPlotClass): # pragma: no cover if 
orientation == "v": figsize, width = (size, 5), 0.4 elif orientation == "h": - figsize, width = (6, (1+.5*size)), 0.65 + figsize, width = (7, (1+.5*size)), 0.65 else: raise ValueError(f"orientation must be `v' or `h' but is: {orientation}") fig, ax = plt.subplots(figsize=figsize) - sns.boxplot(data=data_table, ax=ax, whis=1., color="white", + sns.boxplot(data=data_table, ax=ax, whis=1.5, color="white", showmeans=True, meanprops={"markersize": 6, "markeredgecolor": "k"}, flierprops={"marker": "o", "markerfacecolor": "black", "markeredgecolor": "none", "markersize": 3}, boxprops={'facecolor': 'none', 'edgecolor': 'k'}, @@ -1212,7 +1318,8 @@ class PlotSampleUncertaintyFromBootstrap(AbstractPlotClass): # pragma: no cover else: raise ValueError(f"orientation must be `v' or `h' but is: {orientation}") text = f"n={n_boots}" if self.block_length is None else f"{self.block_length}, n={n_boots}" - text_box = AnchoredText(text, frameon=True, loc=1, pad=0.5) + loc = "upper right" if orientation == "h" else "upper left" + text_box = AnchoredText(text, frameon=True, loc=loc, pad=0.5) plt.setp(text_box.patch, edgecolor='k', facecolor='w') ax.add_artist(text_box) plt.setp(ax.lines, color='k') diff --git a/mlair/plotting/training_monitoring.py b/mlair/plotting/training_monitoring.py index b2b531b99c85bb43e4e758fd23045c9f0575cb24..39dd80651226519463d7b503fb612e43983d73cf 100644 --- a/mlair/plotting/training_monitoring.py +++ b/mlair/plotting/training_monitoring.py @@ -45,15 +45,18 @@ class PlotModelHistory: self._additional_columns = self._filter_columns(history) self._plot(filename) - @staticmethod - def _get_plot_metric(history, plot_metric, main_branch): - if plot_metric.lower() == "mse": - plot_metric = "mean_squared_error" - elif plot_metric.lower() == "mae": - plot_metric = "mean_absolute_error" + def _get_plot_metric(self, history, plot_metric, main_branch, correct_names=True): + _plot_metric = plot_metric + if correct_names is True: + if plot_metric.lower() == "mse": + plot_metric = "mean_squared_error" + elif plot_metric.lower() == "mae": + plot_metric = "mean_absolute_error" available_keys = [k for k in history.keys() if plot_metric in k and ("main" in k.lower() if main_branch else True)] available_keys.sort(key=len) + if len(available_keys) == 0 and correct_names is True: + return self._get_plot_metric(history, _plot_metric, main_branch, correct_names=False) return available_keys[0] def _filter_columns(self, history: Dict) -> List[str]: diff --git a/mlair/run_modules/experiment_setup.py b/mlair/run_modules/experiment_setup.py index 9f9e9bc02132991986802ee2ec75891e1910dbaf..845eeceef54da5e9a1b5c382bf2a9facb6857b74 100644 --- a/mlair/run_modules/experiment_setup.py +++ b/mlair/run_modules/experiment_setup.py @@ -195,6 +195,9 @@ class ExperimentSetup(RunEnvironment): :param use_multiprocessing: Enable parallel preprocessing (postprocessing not implemented yet) by setting this parameter to `True` (default). If set to `False` the computation is performed in an serial approach. Multiprocessing is disabled when running in debug mode and cannot be switched on. + :param transformation_file: Use transformation options from this file for transformation + :param calculate_fresh_transformation: can either be True or False, indicates if new transformation options should + be calculated in any case (transformation_file is not used in this case!). 
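The fallback added to PlotModelHistory._get_plot_metric deals with Keras histories that record losses either under spelled-out names ("mean_squared_error") or under the short alias ("mse"): the corrected name is tried first, the raw name second. A stand-alone sketch of that lookup, not the class method itself:

def find_metric_key(history: dict, plot_metric: str, main_branch: bool = False, correct_names: bool = True) -> str:
    raw_metric = plot_metric
    if correct_names:
        plot_metric = {"mse": "mean_squared_error", "mae": "mean_absolute_error"}.get(plot_metric.lower(), plot_metric)
    keys = [k for k in history if plot_metric in k and ("main" in k.lower() if main_branch else True)]
    keys.sort(key=len)
    if len(keys) == 0 and correct_names:
        return find_metric_key(history, raw_metric, main_branch, correct_names=False)
    return keys[0]

print(find_metric_key({"loss": [], "mean_squared_error": [], "val_mse": []}, "mse"))  # -> "mean_squared_error"
print(find_metric_key({"loss": [], "mse": [], "val_mse": []}, "mse"))                 # -> "mse" via the fallback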
""" @@ -232,7 +235,9 @@ class ExperimentSetup(RunEnvironment): max_number_multiprocessing: int = None, start_script: Union[Callable, str] = None, overwrite_lazy_data: bool = None, uncertainty_estimate_block_length: str = None, uncertainty_estimate_evaluate_competitors: bool = None, uncertainty_estimate_n_boots: int = None, - do_uncertainty_estimate: bool = None, target_var_unit: str = None, **kwargs): + + do_uncertainty_estimate: bool = None, model_display_name: str = None, transformation_file: str = None, + calculate_fresh_transformation: bool = None, target_var_unit: str = None, **kwargs): # create run framework super().__init__() @@ -319,6 +324,9 @@ class ExperimentSetup(RunEnvironment): scope="preprocessing") self._set_param("transformation", transformation, default={}) self._set_param("transformation", None, scope="preprocessing") + self._set_param("transformation_file", transformation_file, default=None) + if calculate_fresh_transformation is not None: + self._set_param("calculate_fresh_transformation", calculate_fresh_transformation) self._set_param("data_handler", data_handler, default=DefaultDataHandler) # iter and window dimension @@ -386,6 +394,8 @@ class ExperimentSetup(RunEnvironment): default=DEFAULT_FEATURE_IMPORTANCE_BOOTSTRAP_TYPE, scope="feature_importance") self._set_param("plot_list", plot_list, default=DEFAULT_PLOT_LIST, scope="general.postprocessing") + if model_display_name is not None: + self._set_param("model_display_name", model_display_name) self._set_param("neighbors", ["DEBW030"]) # TODO: just for testing # set competitors diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py index e148c9f1bef8b0d8b5fd14686dc10588d01c8b58..d0646de163484706a66387197260dc891353dcde 100644 --- a/mlair/run_modules/post_processing.py +++ b/mlair/run_modules/post_processing.py @@ -14,16 +14,17 @@ import tensorflow.keras as keras import numpy as np import pandas as pd import xarray as xr +import dask.array as da from mlair.configuration import path_config from mlair.data_handler import Bootstraps, KerasIterator from mlair.helpers.datastore import NameNotFoundInDataStore -from mlair.helpers import TimeTracking, statistics, extract_value, remove_items, to_list, tables +from mlair.helpers import TimeTracking, TimeTrackingWrapper, statistics, extract_value, remove_items, to_list, tables from mlair.model_modules.linear_model import OrdinaryLeastSquaredModel from mlair.model_modules import AbstractModelClass from mlair.plotting.postprocessing_plotting import PlotMonthlySummary, PlotClimatologicalSkillScore, \ - PlotCompetitiveSkillScore, PlotTimeSeries, PlotBootstrapSkillScore, PlotConditionalQuantiles, \ - PlotSeparationOfScales, PlotSampleUncertaintyFromBootstrap + PlotCompetitiveSkillScore, PlotTimeSeries, PlotFeatureImportanceSkillScore, PlotConditionalQuantiles, \ + PlotSeparationOfScales, PlotSampleUncertaintyFromBootstrap, PlotSectorialSkillScore from mlair.plotting.data_insight_plotting import PlotStationMap, PlotAvailability, PlotAvailabilityHistogram, \ PlotPeriodogram, PlotDataHistogram from mlair.run_modules.run_environment import RunEnvironment @@ -68,7 +69,7 @@ class PostProcessing(RunEnvironment): def __init__(self): """Initialise and run post-processing.""" super().__init__() - self.model: keras.Model = self._load_model() + self.model: AbstractModelClass = self._load_model() self.model_name = self.data_store.get("model_name", "model").rsplit("/", 1)[1].split(".", 1)[0] self.ols_model = None self.batch_size: int = 
self.data_store.get_default("batch_size", "model", 64) @@ -85,6 +86,7 @@ class PostProcessing(RunEnvironment): self._sampling = self.data_store.get("sampling") self.window_lead_time = extract_value(self.data_store.get("output_shape", "model")) self.skill_scores = None + self.skill_score_per_sector = None self.feature_importance_skill_scores = None self.uncertainty_estimate = None self.competitor_path = self.data_store.get("competitor_path") @@ -100,6 +102,9 @@ class PostProcessing(RunEnvironment): self.uncertainty_estimate_boot_dim = "boots" self.model_type_dim = "type" self.index_dim = "index" + self.iter_dim = self.data_store.get("iter_dim") + self.upstream_wind_sector = None + self.model_display_name = self.data_store.get_default("model_display_name", default=self.model.model_name) self._run() def _run(self): @@ -110,6 +115,13 @@ class PostProcessing(RunEnvironment): self.make_prediction(self.test_data) self.make_prediction(self.train_val_data) + # load upstream wind sector for test_data + try: + self.load_upstream_wind_sector(name_of_set="test") + self.skill_score_per_sector = self.calculate_error_metrics_based_on_upstream_wind_dir() + except Exception as e: + logging.info(f"Cannot process upstream wind sectors due to: {e}") + # calculate error metrics on test data self.calculate_test_score() @@ -119,17 +131,17 @@ class PostProcessing(RunEnvironment): # feature importance bootstraps if self.data_store.get("evaluate_feature_importance", "postprocessing"): - with TimeTracking(name="calculate feature importance using bootstraps"): + with TimeTracking(name="evaluate_feature_importance", log_on_enter=True): create_new_bootstraps = self.data_store.get("create_new_bootstraps", "feature_importance") bootstrap_method = self.data_store.get("bootstrap_method", "feature_importance") bootstrap_type = self.data_store.get("bootstrap_type", "feature_importance") self.calculate_feature_importance(create_new_bootstraps, bootstrap_type=bootstrap_type, bootstrap_method=bootstrap_method) - if self.feature_importance_skill_scores is not None: - self.report_feature_importance_results(self.feature_importance_skill_scores) + if self.feature_importance_skill_scores is not None: + self.report_feature_importance_results(self.feature_importance_skill_scores) # skill scores and error metrics - with TimeTracking(name="calculate skill scores"): + with TimeTracking(name="calculate_error_metrics", log_on_enter=True): skill_score_competitive, _, skill_score_climatological, errors = self.calculate_error_metrics() self.skill_scores = (skill_score_competitive, skill_score_climatological) self.report_error_metrics(errors) @@ -139,12 +151,25 @@ class PostProcessing(RunEnvironment): # plotting self.plot() + def load_upstream_wind_sector(self, name_of_set): + path = os.path.join(self.data_store.get("experiment_path"), + f"data/*_{self.data_store.get('start', name_of_set)}_{self.data_store.get('end', name_of_set)}_upstream_wind_sector.nc") + iter_dim = self.data_store.get("iter_dim") + ds = xr.open_mfdataset(path).to_array(iter_dim) + try: + ds = ds.rename({'XTIME': self.index_dim}) + except ValueError as e: + logging.warning("Dimension `XTIME' does not exist") + self.upstream_wind_sector = ds + + @TimeTrackingWrapper def estimate_sample_uncertainty(self, separate_ahead=False): """ Estimate sample uncertainty by using a bootstrap approach. Forecasts are split into individual blocks along time and randomly drawn with replacement.
The resulting behaviour of the error indicates the robustness of each analyzed model to quantify which model might be superior compared to others. """ + logging.info("start estimate_sample_uncertainty") n_boots = self.data_store.get_default("n_boots", default=1000, scope="uncertainty_estimate") block_length = self.data_store.get_default("block_length", default="1m", scope="uncertainty_estimate") evaluate_competitors = self.data_store.get_default("evaluate_competitors", default=True, @@ -277,7 +302,8 @@ class PostProcessing(RunEnvironment): if _iter == 0: self.feature_importance_skill_scores = {} for boot_type in to_list(bootstrap_type): - self.feature_importance_skill_scores[boot_type] = {} + if _iter == 0: + self.feature_importance_skill_scores[boot_type] = {} for boot_method in to_list(bootstrap_method): try: if create_new_bootstraps: @@ -286,13 +312,13 @@ class PostProcessing(RunEnvironment): boot_skill_score = self.calculate_feature_importance_skill_scores(bootstrap_type=boot_type, bootstrap_method=boot_method) self.feature_importance_skill_scores[boot_type][boot_method] = boot_skill_score - except FileNotFoundError: + except (FileNotFoundError, ValueError, OSError): if _iter != 0: - raise RuntimeError(f"calculate_feature_importance ({boot_type}, {boot_type}) was called for the " - f"2nd time. This means, that something internally goes wrong. Please check " - f"for possible errors") - logging.info(f"Could not load all files for feature importance ({boot_type}, {boot_type}), restart " - f"calculate_feature_importance with create_new_bootstraps=True.") + raise RuntimeError(f"calculate_feature_importance ({boot_type}, {boot_method}) was called for " + f"the 2nd time. This means, that something internally goes wrong. Please " + f"check for possible errors.") + logging.info(f"Could not load all files for feature importance ({boot_type}, {boot_method}), " + f"restart calculate_feature_importance with create_new_bootstraps=True.") self.calculate_feature_importance(True, _iter=1, bootstrap_type=boot_type, bootstrap_method=boot_method) @@ -303,26 +329,34 @@ class PostProcessing(RunEnvironment): These forecasts are saved in bootstrap_path with the names `bootstraps_{var}_{station}.nc` and `bootstraps_labels_{station}.nc`. 
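+
+        Sketch (not part of the source) of how one of the stored label files could be inspected afterwards; the
+        forecast path is a placeholder, and station and bootstrap method are example values following the file name
+        pattern used in the code below::
+
+            import os
+            import xarray as xr
+
+            forecast_path = "<experiment_path>/forecasts"  # hypothetical location of the stored files
+            file_name = os.path.join(forecast_path, "bootstraps_DEBW107_shuffle_labels.nc")
+            with xr.open_dataarray(file_name) as da:
+                labels = da.load()  # carries dims[1:] from below, e.g. (index, ahead, type)
+            print(labels.dims, labels.shape)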
""" + + def _reshape(d, pos): + if isinstance(d, list): + return list(map(lambda x: _reshape(x, pos), d)) + else: + return d[..., pos] + # forecast with TimeTracking(name=f"{inspect.stack()[0].function} ({bootstrap_type}, {bootstrap_method})"): # extract all requirements from data store forecast_path = self.data_store.get("forecast_path") number_of_bootstraps = self.data_store.get("n_boots", "feature_importance") - dims = [self.index_dim, self.ahead_dim, self.model_type_dim] + dims = [self.uncertainty_estimate_boot_dim, self.index_dim, self.ahead_dim, self.model_type_dim] for station in self.test_data: X, Y = None, None bootstraps = Bootstraps(station, number_of_bootstraps, bootstrap_type=bootstrap_type, bootstrap_method=bootstrap_method) + number_of_bootstraps = bootstraps.number_of_bootstraps for boot in bootstraps: X, Y, (index, dimension) = boot # make bootstrap predictions - bootstrap_predictions = self.model.predict(X) - if isinstance(bootstrap_predictions, list): # if model is branched model - bootstrap_predictions = bootstrap_predictions[-1] + bootstrap_predictions = [self.model.predict(_reshape(X, pos)) for pos in range(number_of_bootstraps)] + if isinstance(bootstrap_predictions[0], list): # if model is branched model + bootstrap_predictions = list(map(lambda x: x[-1], bootstrap_predictions)) # save bootstrap predictions separately for each station and variable combination - bootstrap_predictions = np.expand_dims(bootstrap_predictions, axis=-1) - shape = bootstrap_predictions.shape - coords = (range(shape[0]), range(1, shape[1] + 1)) + bootstrap_predictions = list(map(lambda x: np.expand_dims(x, axis=-1), bootstrap_predictions)) + shape = bootstrap_predictions[0].shape + coords = (range(number_of_bootstraps), range(shape[0]), range(1, shape[1] + 1)) var = f"{index}_{dimension}" if index is not None else str(dimension) tmp = xr.DataArray(bootstrap_predictions, coords=(*coords, [var]), dims=dims) file_name = os.path.join(forecast_path, @@ -330,9 +364,9 @@ class PostProcessing(RunEnvironment): tmp.to_netcdf(file_name) else: # store also true labels for each station - labels = np.expand_dims(Y, axis=-1) + labels = np.expand_dims(Y[..., 0], axis=-1) file_name = os.path.join(forecast_path, f"bootstraps_{station}_{bootstrap_method}_labels.nc") - labels = xr.DataArray(labels, coords=(*coords, [self.observation_indicator]), dims=dims) + labels = xr.DataArray(labels, coords=(*coords[1:], [self.observation_indicator]), dims=dims[1:]) labels.to_netcdf(file_name) def calculate_feature_importance_skill_scores(self, bootstrap_type, bootstrap_method) -> Dict[str, xr.DataArray]: @@ -363,16 +397,13 @@ class PostProcessing(RunEnvironment): file_name = os.path.join(forecast_path, f"bootstraps_{str(station)}_{bootstrap_method}_labels.nc") with xr.open_dataarray(file_name) as da: labels = da.load() - shape = labels.shape # get original forecasts - orig = self.get_orig_prediction(forecast_path, forecast_file % str(station), number_of_bootstraps) - orig = orig.reshape(shape) - coords = (range(shape[0]), range(1, shape[1] + 1), [reference_name]) - orig = xr.DataArray(orig, coords=coords, dims=[self.index_dim, self.ahead_dim, self.model_type_dim]) + orig = self.get_orig_prediction(forecast_path, forecast_file % str(station), reference_name=reference_name) + orig.coords[self.index_dim] = labels.coords[self.index_dim] # calculate skill scores for each variable - skill = pd.DataFrame(columns=range(1, self.window_lead_time + 1)) + skill = [] for boot_set in bootstrap_iter: boot_var = 
f"{boot_set[0]}_{boot_set[1]}" if isinstance(boot_set, tuple) else str(boot_set) file_name = os.path.join(forecast_path, @@ -385,45 +416,51 @@ class PostProcessing(RunEnvironment): data = boot_data.sel({self.ahead_dim: ahead}) boot_scores.append( skill_scores.general_skill_score(data, forecast_name=boot_var, - reference_name=reference_name)) - skill.loc[boot_var] = np.array(boot_scores) + reference_name=reference_name, dim=self.index_dim)) + tmp = xr.DataArray(np.expand_dims(np.array(boot_scores), axis=-1), + coords={self.ahead_dim: range(1, self.window_lead_time + 1), + self.uncertainty_estimate_boot_dim: range(number_of_bootstraps), + self.boot_var_dim: [boot_var]}, + dims=[self.ahead_dim, self.uncertainty_estimate_boot_dim, self.boot_var_dim]) + skill.append(tmp) # collect all results in single dictionary - score[str(station)] = xr.DataArray(skill, dims=[self.boot_var_dim, self.ahead_dim]) + score[str(station)] = xr.concat(skill, dim=self.boot_var_dim) return score - def get_orig_prediction(self, path, file_name, number_of_bootstraps, prediction_name=None): + def get_orig_prediction(self, path, file_name, prediction_name=None, reference_name=None): if prediction_name is None: prediction_name = self.forecast_indicator file = os.path.join(path, file_name) with xr.open_dataarray(file) as da: - prediction = da.load().sel(type=prediction_name).squeeze() - return self.repeat_data(prediction, number_of_bootstraps) + prediction = da.load().sel({self.model_type_dim: [prediction_name]}) + if reference_name is not None: + prediction.coords[self.model_type_dim] = [reference_name] + return prediction.dropna(dim=self.index_dim) @staticmethod def repeat_data(data, number_of_repetition): if isinstance(data, xr.DataArray): data = data.data - vals = np.tile(data, (number_of_repetition, 1)) - return vals[~np.isnan(vals).any(axis=1), :] + return np.repeat(np.expand_dims(data, axis=-1), number_of_repetition, axis=-1) def _get_model_name(self): """Return model name without path information.""" return self.data_store.get("model_name", "model").rsplit("/", 1)[1].split(".", 1)[0] - def _load_model(self) -> keras.models: + def _load_model(self) -> AbstractModelClass: """ Load NN model either from data store or from local path. :return: the model """ - try: + try: # is only available if a model was trained in training stage model = self.data_store.get("best_model") except NameNotFoundInDataStore: logging.info("No model was saved in data store. 
Try to load model from experiment path.") model_name = self.data_store.get("model_name", "model") - model_class: AbstractModelClass = self.data_store.get("model", "model") - model = keras.models.load_model(model_name, custom_objects=model_class.custom_objects) + model: AbstractModelClass = self.data_store.get("model", "model") + model.load_model(model_name) return model # noinspection PyBroadException @@ -467,26 +504,29 @@ class PostProcessing(RunEnvironment): f"\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}") try: - if (self.feature_importance_skill_scores is not None) and ("PlotBootstrapSkillScore" in plot_list): + if (self.feature_importance_skill_scores is not None) and ("PlotFeatureImportanceSkillScore" in plot_list): for boot_type, boot_data in self.feature_importance_skill_scores.items(): for boot_method, boot_skill_score in boot_data.items(): try: - PlotBootstrapSkillScore(boot_skill_score, plot_folder=self.plot_path, - model_setup=self.forecast_indicator, sampling=self._sampling, - ahead_dim=self.ahead_dim, separate_vars=to_list(self.target_var), - bootstrap_type=boot_type, bootstrap_method=boot_method) + PlotFeatureImportanceSkillScore( + boot_skill_score, plot_folder=self.plot_path, model_name=self.model_display_name, + sampling=self._sampling, ahead_dim=self.ahead_dim, + separate_vars=to_list(self.target_var), bootstrap_type=boot_type, + bootstrap_method=boot_method) except Exception as e: - logging.error(f"Could not create plot PlotBootstrapSkillScore ({boot_type}, {boot_method}) " - f"due to the following error:\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n" - f"{sys.exc_info()[2]}") - + logging.error(f"Could not create plot PlotFeatureImportanceSkillScore ({boot_type}, " + f"{boot_method}) due to the following error:\n{sys.exc_info()[0]}\n" + f"{sys.exc_info()[1]}\n{sys.exc_info()[2]}") except Exception as e: - logging.error(f"Could not create plot PlotBootstrapSkillScore due to the following error: {e}") + logging.error(f"Could not create plot PlotFeatureImportanceSkillScore due to the following error: {e}") try: if "PlotConditionalQuantiles" in plot_list: PlotConditionalQuantiles(self.test_data.keys(), data_pred_path=path, plot_folder=self.plot_path, - target_var_unit=self.target_var_unit) + target_var_unit=self.target_var_unit, + forecast_indicator=self.forecast_indicator, + obs_indicator=self.observation_indicator) + except Exception as e: logging.error(f"Could not create plot PlotConditionalQuantiles due to the following error:" f"\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}") @@ -502,9 +542,9 @@ class PostProcessing(RunEnvironment): try: if "PlotClimatologicalSkillScore" in plot_list: PlotClimatologicalSkillScore(self.skill_scores[1], plot_folder=self.plot_path, - model_setup=self.forecast_indicator) + model_name=self.model_display_name) PlotClimatologicalSkillScore(self.skill_scores[1], plot_folder=self.plot_path, score_only=False, - extra_name_tag="all_terms_", model_setup=self.forecast_indicator) + extra_name_tag="all_terms_", model_name=self.model_display_name) except Exception as e: logging.error(f"Could not create plot PlotClimatologicalSkillScore due to the following error: {e}" f"\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}") @@ -512,12 +552,23 @@ class PostProcessing(RunEnvironment): try: if "PlotCompetitiveSkillScore" in plot_list: PlotCompetitiveSkillScore(self.skill_scores[0], plot_folder=self.plot_path, - model_setup=self.forecast_indicator, sampling=self._sampling, + model_setup=self.model_display_name, 
sampling=self._sampling, model_name_for_plots=self.model_name_for_plots) + except Exception as e: logging.error(f"Could not create plot PlotCompetitiveSkillScore due to the following error: {e}" f"\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}") + try: + if "PlotSectorialSkillScore" in plot_list: + PlotSectorialSkillScore(self.skill_score_per_sector, plot_folder=self.plot_path, + model_setup=self.model_display_name, sampling=self._sampling, + model_name_for_plots=self.model_name_for_plots, ahead_dim=self.ahead_dim + ) + except Exception as e: + logging.error(f"Could not create plot PlotSectorialSkillScore due to the following error: {e}" + f"\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}") + try: if "PlotTimeSeries" in plot_list: PlotTimeSeries(self.test_data.keys(), path, r"forecasts_%s_test.nc", plot_folder=self.plot_path, @@ -583,11 +634,13 @@ class PostProcessing(RunEnvironment): try: if "PlotSampleUncertaintyFromBootstrap" in plot_list and self.uncertainty_estimate is not None: - block_length= self.data_store.get_default("block_length", default="1m", scope="uncertainty_estimate") + block_length = self.data_store.get_default("block_length", default="1m", scope="uncertainty_estimate") PlotSampleUncertaintyFromBootstrap( data=self.uncertainty_estimate, plot_folder=self.plot_path, model_type_dim=self.model_type_dim, dim_name_boots=self.uncertainty_estimate_boot_dim, error_measure="mean squared error", - error_unit=fr"{self.target_var_unit}$^2$", block_length=block_length) + error_unit=fr"{self.target_var_unit}$^2$", block_length=block_length, + model_name=self.model_display_name, model_indicator=self.forecast_indicator) + except Exception as e: logging.error(f"Could not create plot PlotSampleUncertaintyFromBootstrap due to the following error: {e}" f"\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}") @@ -620,6 +673,7 @@ class PostProcessing(RunEnvironment): logging.info(f"start make_prediction for {subset_type}") time_dimension = self.data_store.get("time_dim") window_dim = self.data_store.get("window_dim") + path = self.data_store.get("forecast_path") subset_type = subset.name for i, data in enumerate(subset): input_data = data.get_X() @@ -659,7 +713,6 @@ class PostProcessing(RunEnvironment): **prediction_dict) # save all forecasts locally - path = self.data_store.get("forecast_path") prefix = "forecasts_norm" if normalised is True else "forecasts" file = os.path.join(path, f"{prefix}_{str(data)}_{subset_type}.nc") all_predictions.to_netcdf(file) @@ -879,6 +932,45 @@ class PostProcessing(RunEnvironment): except (TypeError, AttributeError): return forecast if competitor is None else competitor + @TimeTrackingWrapper + def load_forecast_array_and_store_as_dataset(self): + path = self.data_store.get("forecast_path") + all_stations = self.data_store.get("stations") + for station in all_stations: + external_data = self._get_external_data(station, path) # test data + if external_data is not None: + external_data.coords[self.model_type_dim] = [ + {self.forecast_indicator: self.model_display_name}.get(n, n) + for n in external_data.coords[self.model_type_dim].values] + external_data_expd = external_data.assign_coords({self.iter_dim: station}) + external_data_expd = external_data_expd.expand_dims(self.iter_dim).to_dataset(self.iter_dim) + external_data_expd.to_netcdf(os.path.join(path, f"forecasts_ds_{str(station)}_test.nc")) + + def calculate_error_metrics_based_on_upstream_wind_dir(self): + self.load_forecast_array_and_store_as_dataset() + path = 
self.data_store.get("forecast_path") + files = os.path.join(path, "forecasts_ds_*_test.nc") + ds = xr.open_mfdataset(files) + ds = ds.to_array(self.iter_dim) + wind_sectors = self.data_store.get("wind_sectors", "general") + sector_collector = dict() + h_sector_skill_scores = [] + for sec in wind_sectors: + h_sector_skill_scores.append(statistics.skill_score_based_on_mse( + ds.where(self.upstream_wind_sector.squeeze() == sec), + obs_name=self.observation_indicator, pred_name=self.model_display_name, + ref_name="ols").assign_coords({"sector": sec}) + ) + + #sec_coords = da.argwhere(self.upstream_wind_sector.squeeze() == sec) + #sector_collector[sec] = dict() + #for i, dim in enumerate(self.upstream_wind_sector.squeeze().dims): + # sector_collector[sec][dim] = sec_coords[:,i] + sector_skill_scores = xr.concat(h_sector_skill_scores, dim="sector") + #.to_dataframe("data")[["data"]].stack(level=0).reset_index(level=2, drop=True).reset_index(name="data") + return sector_skill_scores + + def calculate_error_metrics(self) -> Tuple[Dict, Dict, Dict, Dict]: """ Calculate error metrics and skill scores of NN forecast. @@ -900,6 +992,8 @@ class PostProcessing(RunEnvironment): # test errors if external_data is not None: + external_data.coords[self.model_type_dim] = [{self.forecast_indicator: self.model_display_name}.get(n, n) + for n in external_data.coords[self.model_type_dim].values] model_type_list = external_data.coords[self.model_type_dim].values.tolist() for model_type in remove_items(model_type_list, self.observation_indicator): if model_type not in errors.keys(): @@ -918,7 +1012,7 @@ class PostProcessing(RunEnvironment): model_list = None # test errors of competitors - for model_type in remove_items(model_list or [], list(errors.keys())): + for model_type in (model_list or []): if self.observation_indicator not in combined.coords[self.model_type_dim]: continue if model_type not in errors.keys(): @@ -971,15 +1065,17 @@ class PostProcessing(RunEnvironment): """Create a csv file containing all results from feature importance.""" report_path = os.path.join(self.data_store.get("experiment_path"), "latex_report") path_config.check_path_and_create(report_path) - res = [[self.model_type_dim, "method", "station", self.boot_var_dim, self.ahead_dim, "vals"]] + res = [] for boot_type, d0 in results.items(): for boot_method, d1 in d0.items(): for station_name, vals in d1.items(): for boot_var in vals.coords[self.boot_var_dim].values.tolist(): for ahead in vals.coords[self.ahead_dim].values.tolist(): res.append([boot_type, boot_method, station_name, boot_var, ahead, - float(vals.sel({self.boot_var_dim: boot_var, self.ahead_dim: ahead}))]) - col_names = res.pop(0) + *vals.sel({self.boot_var_dim: boot_var, + self.ahead_dim: ahead}).values.round(5).tolist()]) + col_names = [self.model_type_dim, "method", "station", self.boot_var_dim, self.ahead_dim, + *list(range(len(res[0]) - 5))] df = pd.DataFrame(res, columns=col_names) file_name = "feature_importance_skill_score_report_raw.csv" df.to_csv(os.path.join(report_path, file_name), sep=";") @@ -1014,8 +1110,8 @@ class PostProcessing(RunEnvironment): df.reindex(df.index.drop(["total"]).to_list() + ["total"], ) column_format = tables.create_column_format_for_tex(df) if model_type == "skill_score": - file_name = f"error_report_{model_type}_{metric}.%s".replace(' ', '_') + file_name = f"error_report_{model_type}_{metric}.%s".replace(' ', '_').replace('/', '_') else: - file_name = f"error_report_{metric}_{model_type}.%s".replace(' ', '_') + file_name = 
f"error_report_{metric}_{model_type}.%s".replace(' ', '_').replace('/', '_') tables.save_to_tex(report_path, file_name % "tex", column_format=column_format, df=df) tables.save_to_md(report_path, file_name % "md", df=df) diff --git a/mlair/run_modules/pre_processing.py b/mlair/run_modules/pre_processing.py index f0192e469c0a52a68fc5b29a5254583d5207cc32..c094d1af42731348e120e0dcddebaaf9439435bb 100644 --- a/mlair/run_modules/pre_processing.py +++ b/mlair/run_modules/pre_processing.py @@ -246,7 +246,7 @@ class PreProcessing(RunEnvironment): # start station check collection = DataCollection(name=set_name) valid_stations = [] - kwargs = self.data_store.create_args_dict(data_handler.requirements(), scope=set_name) + kwargs = self.data_store.create_args_dict(data_handler.requirements(skip_args="station"), scope=set_name) use_multiprocessing = self.data_store.get("use_multiprocessing") tmp_path = self.data_store.get("tmp_path") @@ -270,6 +270,7 @@ class PreProcessing(RunEnvironment): collection.add(dh) valid_stations.append(s) pool.close() + pool.join() else: # serial solution logging.info("use serial validate station approach") kwargs.update({"return_strategy": "result"}) @@ -298,12 +299,43 @@ class PreProcessing(RunEnvironment): self.data_store.set(k, v) def transformation(self, data_handler: AbstractDataHandler, stations): + calculate_fresh_transformation = self.data_store.get_default("calculate_fresh_transformation", True) if hasattr(data_handler, "transformation"): - kwargs = self.data_store.create_args_dict(data_handler.requirements(), scope="train") - tmp_path = self.data_store.get_default("tmp_path", default=None) - transformation_dict = data_handler.transformation(stations, tmp_path=tmp_path, **kwargs) - if transformation_dict is not None: - self.data_store.set("transformation", transformation_dict) + transformation_opts = None if calculate_fresh_transformation is True else self._load_transformation() + if transformation_opts is None: + logging.info(f"start to calculate transformation parameters.") + kwargs = self.data_store.create_args_dict(data_handler.requirements(skip_args="station"), scope="train") + tmp_path = self.data_store.get_default("tmp_path", default=None) + transformation_opts = data_handler.transformation(stations, tmp_path=tmp_path, **kwargs) + else: + logging.info("If no valid train data can be found due to transformation problems, please check the " + "provided transformation file for compatibility with your data.") + self.data_store.set("transformation", transformation_opts) + if transformation_opts is not None: + self._store_transformation(transformation_opts) + + def _load_transformation(self): + """Try to load transformation options from file if transformation_file is provided.""" + transformation_file = self.data_store.get_default("transformation_file", None) + if transformation_file is not None: + if os.path.exists(transformation_file): + logging.info(f"use transformation from given transformation file: {transformation_file}") + with open(transformation_file, "rb") as pickle_file: + return dill.load(pickle_file) + else: + logging.info(f"cannot load transformation file: {transformation_file}.
Use fresh calculation of " + f"transformation from train data.") + + def _store_transformation(self, transformation_opts): + """Store transformation options locally inside experiment_path if not exists already.""" + experiment_path = self.data_store.get("experiment_path") + transformation_path = os.path.join(experiment_path, "data", "transformation") + transformation_file = os.path.join(transformation_path, "transformation.pickle") + if not os.path.exists(transformation_file): + path_config.check_path_and_create(transformation_path) + with open(transformation_file, "wb") as f: + dill.dump(transformation_opts, f, protocol=4) + logging.info(f"Store transformation options locally for later use at: {transformation_file}") def prepare_competitors(self): """ diff --git a/mlair/run_modules/training.py b/mlair/run_modules/training.py index 0696c2e7b8daa75925cf16096e183de94c21fe85..a38837dce041295d37fae1ea86ef2a215d51dc89 100644 --- a/mlair/run_modules/training.py +++ b/mlair/run_modules/training.py @@ -14,6 +14,7 @@ import psutil import pandas as pd from mlair.data_handler import KerasIterator +from mlair.model_modules import AbstractModelClass from mlair.model_modules.keras_extensions import CallbackHandler from mlair.plotting.training_monitoring import PlotModelHistory, PlotModelLearningRate from mlair.run_modules.run_environment import RunEnvironment @@ -67,10 +68,10 @@ class Training(RunEnvironment): def __init__(self): """Set up and run training.""" super().__init__() - self.model: keras.Model = self.data_store.get("model", "model") + self.model: AbstractModelClass = self.data_store.get("model", "model") self.train_set: Union[KerasIterator, None] = None self.val_set: Union[KerasIterator, None] = None - self.test_set: Union[KerasIterator, None] = None + # self.test_set: Union[KerasIterator, None] = None self.batch_size = self.data_store.get("batch_size") self.epochs = self.data_store.get("epochs") self.callbacks: CallbackHandler = self.data_store.get("callbacks", "model") @@ -81,9 +82,9 @@ class Training(RunEnvironment): def _run(self) -> None: """Run training. Details in class description.""" - self.set_generators() self.make_predict_function() if self._train_model: + self.set_generators() self.train() self.save_model() self.report_training() @@ -118,7 +119,9 @@ class Training(RunEnvironment): The called sub-method will automatically distribute the data according to the batch size. The subsets can be accessed as class variables train_set, val_set, and test_set. """ - for mode in ["train", "val", "test"]: + logging.info("set generators for training and validation") + # for mode in ["train", "val", "test"]: + for mode in ["train", "val"]: self._set_gen(mode) def train(self) -> None: @@ -149,7 +152,7 @@ class Training(RunEnvironment): logging.info("Found locally stored model and checkpoints. 
Training is resumed from the last checkpoint.") self.callbacks.load_callbacks() self.callbacks.update_checkpoint() - self.model = keras.models.load_model(checkpoint.filepath) + self.model.load_model(checkpoint.filepath, compile=True) hist: History = self.callbacks.get_callback_by_name("hist") initial_epoch = max(hist.epoch) + 1 _ = self.model.fit(self.train_set, @@ -179,6 +182,7 @@ class Training(RunEnvironment): model_name = self.data_store.get("model_name", "model") logging.debug(f"save best model to {model_name}") self.model.save(model_name, save_format='h5') + self.model.save(model_name) self.data_store.set("best_model", self.model) def load_best_model(self, name: str) -> None: @@ -189,8 +193,8 @@ class Training(RunEnvironment): """ logging.debug(f"load best model: {name}") try: - self.model.load_weights(name) - logging.info('reload weights...') + self.model.load_model(name, compile=True) + logging.info('reload model...') except OSError: logging.info('no weights to reload...') @@ -235,9 +239,11 @@ class Training(RunEnvironment): if multiple_branches_used: filename = os.path.join(path, f"{name}_history_main_loss.pdf") PlotModelHistory(filename=filename, history=history, main_branch=True) - if len([e for e in history.model.metrics_names if "mean_squared_error" in e]) > 0: + mse_indicator = list(set(history.model.metrics_names).intersection(["mean_squared_error", "mse"])) + if len(mse_indicator) > 0: filename = os.path.join(path, f"{name}_history_main_mse.pdf") - PlotModelHistory(filename=filename, history=history, plot_metric="mse", main_branch=multiple_branches_used) + PlotModelHistory(filename=filename, history=history, plot_metric=mse_indicator[0], + main_branch=multiple_branches_used) # plot learning rate if lr_sc: diff --git a/requirements.txt b/requirements.txt index 8d21c80db974033c94985821564e26cbb4aa8088..3afc17b67fddbf5a269df1e1b7e103045630a290 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,23 +1,34 @@ -## this list was generated using pipreqs on mlair/ directory astropy==4.1 auto_mix_prep==0.2.0 Cartopy==0.18.0 +dask==2021.3.0 dill==0.3.3 +fsspec==2021.11.0 keras==2.6.0 keras_nightly==2.5.0.dev2021032900 +locket==0.2.1 matplotlib==3.3.4 mock==4.0.3 +netcdf4==1.5.8 numpy==1.19.5 pandas==1.1.5 +partd==1.2.0 psutil==5.8.0 +pydot==1.4.2 pytest==6.2.2 +pytest-cov==2.11.1 +pytest-html==3.1.1 +pytest-lazy-fixture==0.6.3 requests==2.25.1 scipy==1.5.2 seaborn==0.11.1 setuptools==47.1.0 +--no-binary shapely Shapely==1.8.0 six==1.15.0 statsmodels==0.12.2 +tabulate==0.8.9 tensorflow==2.5.0 +toolz==0.11.2 typing_extensions==3.7.4.3 wget==3.2 xarray==0.16.2 diff --git a/requirements_vm_local.txt b/requirements_vm_local.txt deleted file mode 100644 index d57cfb8e0b75055e187816b9922f72ac510cbd7d..0000000000000000000000000000000000000000 --- a/requirements_vm_local.txt +++ /dev/null @@ -1,103 +0,0 @@ -absl-py==0.11.0 -appdirs==1.4.4 -astor==0.8.1 -astropy==4.1 -astunparse==1.6.3 -attrs==20.3.0 -Bottleneck==1.3.2 -cached-property==1.5.2 -cachetools==4.2.4 -Cartopy==0.18.0 -certifi==2020.12.5 -cftime==1.4.1 -chardet==4.0.0 -click==8.0.3 -cloudpickle==2.0.0 -coverage==5.4 -cycler==0.10.0 -dask==2021.10.0 -dill==0.3.3 -distributed==2021.10.0 -flatbuffers==1.12 -fsspec==0.8.5 -gast==0.4.0 -google-auth==2.3.0 -google-auth-oauthlib==0.4.6 -google-pasta==0.2.0 -greenlet==1.1.2 -grpcio==1.34.0 -h5py==3.1.0 -HeapDict==1.0.1 -idna==2.10 -importlib-metadata==3.4.0 -iniconfig==1.1.1 -Jinja2==3.0.2 -joblib==1.1.0 -keras-nightly==2.5.0.dev2021032900 -Keras-Preprocessing==1.1.2 -kiwisolver==1.3.1 
-locket==0.2.1 -Markdown==3.3.3 -MarkupSafe==2.0.1 -matplotlib==3.3.4 -mock==4.0.3 -msgpack==1.0.2 -netCDF4==1.5.5.1 -numpy==1.19.5 -oauthlib==3.1.1 -opt-einsum==3.3.0 -ordered-set==4.0.2 -packaging==20.9 -pandas==1.1.5 -partd==1.1.0 -patsy==0.5.1 -Pillow==8.1.0 -pluggy==0.13.1 -protobuf==3.15.0 -psutil==5.8.0 -py==1.10.0 -pyasn1==0.4.8 -pyasn1-modules==0.2.8 -pydot==1.4.2 -pyparsing==2.4.7 -pyshp==2.1.3 -pytest==6.2.2 -pytest-cov==2.11.1 -pytest-html==3.1.1 -pytest-lazy-fixture==0.6.3 -pytest-metadata==1.11.0 -pytest-sugar==0.9.4 -python-dateutil==2.8.1 -pytz==2021.1 -PyYAML==5.4.1 -requests==2.25.1 -requests-oauthlib==1.3.0 -rsa==4.7.2 -scikit-learn==1.0.1 -scipy==1.5.2 -seaborn==0.11.1 -Shapely==1.7.1 -six==1.15.0 -sortedcontainers==2.4.0 -SQLAlchemy==1.4.26 -statsmodels==0.12.2 -tabulate==0.8.8 -tblib==1.7.0 -tensorboard==2.7.0 -tensorboard-data-server==0.6.1 -tensorboard-plugin-wit==1.8.0 -tensorflow==2.5.0 -tensorflow-estimator==2.5.0 -termcolor==1.1.0 -threadpoolctl==3.0.0 -toml==0.10.2 -toolz==0.11.1 -tornado==6.1 -typing-extensions==3.7.4.3 -urllib3==1.26.3 -Werkzeug==1.0.1 -wget==3.2 -wrapt==1.12.1 -xarray==0.16.2 -zict==2.0.0 -zipp==3.4.0 diff --git a/run.py b/run.py index 5324e55a09b004352c4e35f23f5e2ea21a7451d6..82bb0e2814d403b5be602eaebd1bc44b6cf6d6f9 100644 --- a/run.py +++ b/run.py @@ -30,7 +30,6 @@ def main(parser_args): train_model=False, create_new_model=True, network="UBA", evaluate_feature_importance=False, # plot_list=["PlotCompetitiveSkillScore"], competitors=["test_model", "test_model2"], - competitor_path=os.path.join(os.getcwd(), "data", "comp_test"), **parser_args.__dict__, start_script=__file__) workflow.run() diff --git a/run_mixed_sampling.py b/run_mixed_sampling.py index 784f653fbfb2eb4c78e6e858acf67cd0ae47a593..47aa9b970c0e95ccadb60e8c090136c0fa6ceea4 100644 --- a/run_mixed_sampling.py +++ b/run_mixed_sampling.py @@ -4,8 +4,8 @@ __date__ = '2019-11-14' import argparse from mlair.workflows import DefaultWorkflow -from mlair.data_handler.data_handler_mixed_sampling import DataHandlerMixedSampling, DataHandlerMixedSamplingWithFilter, \ - DataHandlerSeparationOfScales +from mlair.data_handler.data_handler_mixed_sampling import DataHandlerMixedSampling + stats = {'o3': 'dma8eu', 'no': 'dma8eu', 'no2': 'dma8eu', 'relhum': 'average_values', 'u': 'average_values', 'v': 'average_values', @@ -20,7 +20,7 @@ data_origin = {'o3': '', 'no': '', 'no2': '', def main(parser_args): args = dict(stations=["DEBW107", "DEBW013"], network="UBA", - evaluate_feature_importance=False, plot_list=[], + evaluate_feature_importance=True, # plot_list=[], data_origin=data_origin, data_handler=DataHandlerMixedSampling, interpolation_limit=(3, 1), overwrite_local_data=False, sampling=("hourly", "daily"), @@ -28,8 +28,6 @@ def main(parser_args): create_new_model=True, train_model=False, epochs=1, window_history_size=6 * 24 + 16, window_history_offset=16, - kz_filter_length=[100 * 24, 15 * 24], - kz_filter_iter=[4, 5], start="2006-01-01", train_start="2006-01-01", end="2011-12-31", diff --git a/test/test_configuration/test_defaults.py b/test/test_configuration/test_defaults.py index 8644181185203186bb6c8549e8faa99e75a31a81..f6bc6d24724c2620083602d3864bcbca0a709681 100644 --- a/test/test_configuration/test_defaults.py +++ b/test/test_configuration/test_defaults.py @@ -72,8 +72,6 @@ class TestAllDefaults: assert DEFAULT_FEATURE_IMPORTANCE_BOOTSTRAP_TYPE == "singleinput" assert DEFAULT_FEATURE_IMPORTANCE_BOOTSTRAP_METHOD == "shuffle" assert DEFAULT_PLOT_LIST == ["PlotMonthlySummary", 
"PlotStationMap", "PlotClimatologicalSkillScore", - "PlotTimeSeries", "PlotCompetitiveSkillScore", "PlotBootstrapSkillScore", + "PlotTimeSeries", "PlotCompetitiveSkillScore", "PlotFeatureImportanceSkillScore", "PlotConditionalQuantiles", "PlotAvailability", "PlotAvailabilityHistogram", "PlotDataHistogram", "PlotPeriodogram", "PlotSampleUncertaintyFromBootstrap"] - - diff --git a/test/test_data_handler/test_data_handler.py b/test/test_data_handler/test_abstract_data_handler.py similarity index 90% rename from test/test_data_handler/test_data_handler.py rename to test/test_data_handler/test_abstract_data_handler.py index 418c7946efe160c9bbfeccff9908a6cf17dec17f..5166717471cb9b98a53cc33462fd65e13d142b5b 100644 --- a/test/test_data_handler/test_data_handler.py +++ b/test/test_data_handler/test_abstract_data_handler.py @@ -4,11 +4,12 @@ import inspect from mlair.data_handler.abstract_data_handler import AbstractDataHandler -class TestDefaultDataHandler: +class TestAbstractDataHandler: def test_required_attributes(self): dh = AbstractDataHandler assert hasattr(dh, "_requirements") + assert hasattr(dh, "_skip_args") assert hasattr(dh, "__init__") assert hasattr(dh, "build") assert hasattr(dh, "requirements") @@ -35,8 +36,12 @@ class TestDefaultDataHandler: def test_own_args(self): dh = AbstractDataHandler() assert isinstance(dh.own_args(), list) - assert len(dh.own_args()) == 0 - assert "self" not in dh.own_args() + assert len(dh.own_args()) == 1 + assert "self" in dh.own_args() + + def test_skip_args(self): + dh = AbstractDataHandler() + assert dh._skip_args == ["self"] def test_transformation(self): assert AbstractDataHandler.transformation() is None diff --git a/test/test_data_handler/test_data_handler_mixed_sampling.py b/test/test_data_handler/test_data_handler_mixed_sampling.py index 7418a435008f06a9016f903fe140b51d0a7c8106..0515278a8ae77880de99b0de4abf7fa85198fe49 100644 --- a/test/test_data_handler/test_data_handler_mixed_sampling.py +++ b/test/test_data_handler/test_data_handler_mixed_sampling.py @@ -2,13 +2,16 @@ __author__ = 'Lukas Leufen' __date__ = '2020-12-10' from mlair.data_handler.data_handler_mixed_sampling import DataHandlerMixedSampling, \ - DataHandlerMixedSamplingSingleStation, DataHandlerMixedSamplingWithKzFilter, \ - DataHandlerMixedSamplingWithKzFilterSingleStation, DataHandlerSeparationOfScales, \ - DataHandlerSeparationOfScalesSingleStation, DataHandlerMixedSamplingWithFilterSingleStation -from mlair.data_handler.data_handler_with_filter import DataHandlerKzFilterSingleStation + DataHandlerMixedSamplingSingleStation, DataHandlerMixedSamplingWithFilterSingleStation, \ + DataHandlerMixedSamplingWithFirFilterSingleStation, DataHandlerMixedSamplingWithFirFilter, \ + DataHandlerFirFilterSingleStation, DataHandlerMixedSamplingWithClimateFirFilterSingleStation, \ + DataHandlerMixedSamplingWithClimateFirFilter +from mlair.data_handler.data_handler_with_filter import DataHandlerFilter, DataHandlerFilterSingleStation, \ + DataHandlerClimateFirFilterSingleStation from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation -from mlair.helpers import remove_items +from mlair.data_handler.default_data_handler import DefaultDataHandler from mlair.configuration.defaults import DEFAULT_INTERPOLATION_METHOD +from mlair.helpers.testing import get_all_args import pytest import mock @@ -25,17 +28,23 @@ class TestDataHandlerMixedSampling: assert obj.data_handler_transformation.__qualname__ == DataHandlerMixedSamplingSingleStation.__qualname__ def 
test_requirements(self): + reqs = get_all_args(DefaultDataHandler) obj = object.__new__(DataHandlerMixedSampling) - req = object.__new__(DataHandlerSingleStation) - assert sorted(obj._requirements) == sorted(remove_items(req.requirements(), "station")) + assert sorted(obj.own_args()) == reqs + reqs = get_all_args(DataHandlerSingleStation, remove="self") + assert sorted(obj._requirements) == reqs + reqs = get_all_args(DataHandlerSingleStation, DefaultDataHandler, remove=["self", "id_class"]) + assert sorted(obj.requirements()) == reqs class TestDataHandlerMixedSamplingSingleStation: def test_requirements(self): + reqs = get_all_args(DataHandlerSingleStation) obj = object.__new__(DataHandlerMixedSamplingSingleStation) - req = object.__new__(DataHandlerSingleStation) - assert sorted(obj._requirements) == sorted(remove_items(req.requirements(), "station")) + assert sorted(obj.own_args()) == reqs + reqs = get_all_args(DataHandlerSingleStation, remove="self") + assert sorted(obj.requirements()) == reqs @mock.patch("mlair.data_handler.data_handler_single_station.DataHandlerSingleStation.setup_samples") def test_init(self, mock_super_init): @@ -86,45 +95,97 @@ class TestDataHandlerMixedSamplingSingleStation: pass -class TestDataHandlerMixedSamplingWithKzFilter: +class TestDataHandlerMixedSamplingWithFilterSingleStation: - def test_data_handler(self): - obj = object.__new__(DataHandlerMixedSamplingWithKzFilter) - assert obj.data_handler.__qualname__ == DataHandlerMixedSamplingWithKzFilterSingleStation.__qualname__ + def test_requirements(self): - def test_data_handler_transformation(self): - obj = object.__new__(DataHandlerMixedSamplingWithKzFilter) - assert obj.data_handler_transformation.__qualname__ == DataHandlerMixedSamplingWithKzFilterSingleStation.__qualname__ + reqs = get_all_args(DataHandlerFilterSingleStation, DataHandlerSingleStation) + obj = object.__new__(DataHandlerMixedSamplingWithFilterSingleStation) + assert sorted(obj.own_args()) == reqs + reqs = get_all_args(DataHandlerFilterSingleStation, DataHandlerSingleStation, remove="self") + assert sorted(obj._requirements) == [] + assert sorted(obj.requirements()) == reqs - def test_requirements(self): - obj = object.__new__(DataHandlerMixedSamplingWithKzFilter) - req1 = object.__new__(DataHandlerMixedSamplingWithFilterSingleStation) - req2 = object.__new__(DataHandlerKzFilterSingleStation) - req = list(set(req1.requirements() + req2.requirements())) - assert sorted(obj._requirements) == sorted(remove_items(req, "station")) +class TestDataHandlerMixedSamplingWithFirFilter: -class TestDataHandlerMixedSamplingWithFilterSingleStation: - pass + def test_requirements(self): + reqs = get_all_args(DataHandlerFilter, DefaultDataHandler) + obj = object.__new__(DataHandlerMixedSamplingWithFirFilter) + assert sorted(obj.own_args()) == reqs + reqs = get_all_args(DataHandlerFilterSingleStation, DataHandlerSingleStation, DataHandlerFirFilterSingleStation, + remove=["self"]) + assert sorted(obj._requirements) == reqs + reqs = get_all_args(DataHandlerFilterSingleStation, DataHandlerSingleStation, DataHandlerFilter, + DataHandlerFirFilterSingleStation, DefaultDataHandler, remove=["self", "id_class"]) + assert sorted(obj.requirements()) == reqs -class TestDataHandlerSeparationOfScales: +class TestDataHandlerMixedSamplingWithFirFilterSingleStation: - def test_data_handler(self): - obj = object.__new__(DataHandlerSeparationOfScales) - assert obj.data_handler.__qualname__ == DataHandlerSeparationOfScalesSingleStation.__qualname__ + def 
test_requirements(self): + reqs = get_all_args(DataHandlerFirFilterSingleStation, DataHandlerFilterSingleStation, DataHandlerSingleStation) + obj = object.__new__(DataHandlerMixedSamplingWithFirFilterSingleStation) + assert sorted(obj.own_args()) == reqs + assert sorted(obj._requirements) == [] + reqs = get_all_args(DataHandlerFilterSingleStation, DataHandlerFirFilterSingleStation, DataHandlerSingleStation, + remove="self") + assert sorted(obj.requirements()) == reqs - def test_data_handler_transformation(self): - obj = object.__new__(DataHandlerSeparationOfScales) - assert obj.data_handler_transformation.__qualname__ == DataHandlerSeparationOfScalesSingleStation.__qualname__ + +class TestDataHandlerMixedSamplingWithClimateFirFilter: def test_requirements(self): - obj = object.__new__(DataHandlerMixedSamplingWithKzFilter) - req1 = object.__new__(DataHandlerMixedSamplingWithFilterSingleStation) - req2 = object.__new__(DataHandlerKzFilterSingleStation) - req = list(set(req1.requirements() + req2.requirements())) - assert sorted(obj._requirements) == sorted(remove_items(req, "station")) + reqs = get_all_args(DataHandlerMixedSamplingWithClimateFirFilter, DataHandlerFilter, DefaultDataHandler) + obj = object.__new__(DataHandlerMixedSamplingWithClimateFirFilter) + assert sorted(obj.own_args()) == reqs + reqs = get_all_args(DataHandlerFilterSingleStation, DataHandlerClimateFirFilterSingleStation, + DataHandlerSingleStation, DataHandlerFirFilterSingleStation, remove=["self"]) + assert sorted(obj._requirements) == reqs + reqs = get_all_args(DataHandlerFilterSingleStation, DataHandlerClimateFirFilterSingleStation, + DataHandlerSingleStation, DataHandlerFilter, DataHandlerMixedSamplingWithClimateFirFilter, + DefaultDataHandler, DataHandlerFirFilterSingleStation, remove=["self", "id_class"]) + assert sorted(obj.requirements()) == reqs -class TestDataHandlerSeparationOfScalesSingleStation: - pass +class TestDataHandlerMixedSamplingWithClimateFirFilterSingleStation: + + def test_requirements(self): + reqs = get_all_args(DataHandlerClimateFirFilterSingleStation, DataHandlerFirFilterSingleStation, + DataHandlerFilterSingleStation, DataHandlerSingleStation) + obj = object.__new__(DataHandlerMixedSamplingWithClimateFirFilterSingleStation) + assert sorted(obj.own_args()) == reqs + assert sorted(obj._requirements) == [] + reqs = get_all_args(DataHandlerFilterSingleStation, DataHandlerFirFilterSingleStation, DataHandlerSingleStation, + DataHandlerClimateFirFilterSingleStation, remove="self") + assert sorted(obj.requirements()) == reqs + + +# class TestDataHandlerSeparationOfScales: +# +# def test_data_handler(self): +# obj = object.__new__(DataHandlerSeparationOfScales) +# assert obj.data_handler.__qualname__ == DataHandlerSeparationOfScalesSingleStation.__qualname__ +# +# def test_data_handler_transformation(self): +# obj = object.__new__(DataHandlerSeparationOfScales) +# assert obj.data_handler_transformation.__qualname__ == DataHandlerSeparationOfScalesSingleStation.__qualname__ +# +# def test_requirements(self): +# reqs = get_all_args(DefaultDataHandler) +# obj = object.__new__(DataHandlerSeparationOfScales) +# assert sorted(obj.own_args()) == reqs +# +# reqs = get_all_args(DataHandlerSeparationOfScalesSingleStation, DataHandlerKzFilterSingleStation, +# DataHandlerMixedSamplingWithKzFilterSingleStation,DataHandlerFilterSingleStation, +# DataHandlerSingleStation, remove=["self", "id_class"]) +# assert sorted(obj._requirements) == reqs +# reqs = get_all_args(DataHandlerSeparationOfScalesSingleStation, 
DataHandlerKzFilterSingleStation, +# DataHandlerMixedSamplingWithKzFilterSingleStation,DataHandlerFilterSingleStation, +# DataHandlerSingleStation, DefaultDataHandler, remove=["self", "id_class"]) +# assert sorted(obj.requirements()) == reqs + +# +# class TestDataHandlerSeparationOfScalesSingleStation: +# pass + diff --git a/test/test_data_handler/test_data_handler_with_filter.py b/test/test_data_handler/test_data_handler_with_filter.py new file mode 100644 index 0000000000000000000000000000000000000000..b83effd96ec7a496977873af0785a8406fa7114e --- /dev/null +++ b/test/test_data_handler/test_data_handler_with_filter.py @@ -0,0 +1,87 @@ +import pytest + +from mlair.data_handler.data_handler_with_filter import DataHandlerFilter, DataHandlerFilterSingleStation, \ + DataHandlerFirFilter, DataHandlerFirFilterSingleStation, DataHandlerClimateFirFilter, \ + DataHandlerClimateFirFilterSingleStation +from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation +from mlair.data_handler.default_data_handler import DefaultDataHandler +from mlair.helpers.testing import get_all_args + + +class TestDataHandlerFilter: + + def test_requirements(self): + reqs = get_all_args(DataHandlerFilter, DefaultDataHandler) + obj = object.__new__(DataHandlerFilter) + assert sorted(obj.own_args()) == reqs + reqs = get_all_args(DataHandlerSingleStation, DataHandlerFilterSingleStation, remove=["self"]) + assert sorted(obj._requirements) == reqs + reqs = get_all_args(DataHandlerSingleStation, DataHandlerFilterSingleStation, DefaultDataHandler, + DataHandlerFilter, remove=["self", "id_class"]) + assert sorted(obj.requirements()) == reqs + + +class TestDataHandlerFilterSingleStation: + + def test_requirements(self): + reqs = get_all_args(DataHandlerFilterSingleStation, DataHandlerSingleStation) + obj = object.__new__(DataHandlerFilterSingleStation) + assert sorted(obj.own_args()) == reqs + assert sorted(obj._requirements) == [] + reqs = get_all_args(DataHandlerFilterSingleStation, DataHandlerSingleStation, remove="self") + assert sorted(obj.requirements()) == reqs + + +class TestDataHandlerFirFilter: + + def test_requirements(self): + reqs = get_all_args(DataHandlerFilter, DefaultDataHandler) + obj = object.__new__(DataHandlerFirFilter) + assert sorted(obj.own_args()) == reqs + reqs = get_all_args(DataHandlerSingleStation, DataHandlerFirFilterSingleStation, DataHandlerFilterSingleStation, + remove=["self"]) + assert sorted(obj._requirements) == reqs + reqs = get_all_args(DataHandlerSingleStation, DataHandlerFirFilterSingleStation, DataHandlerFilterSingleStation, + DataHandlerFilter, DefaultDataHandler, remove=["self", "id_class"]) + assert sorted(obj.requirements()) == reqs + + +class TestDataHandlerFirFilterSingleStation: + + def test_requirements(self): + + reqs = get_all_args(DataHandlerFirFilterSingleStation, DataHandlerFilterSingleStation, DataHandlerSingleStation) + obj = object.__new__(DataHandlerFirFilterSingleStation) + assert sorted(obj.own_args()) == reqs + assert sorted(obj._requirements) == [] + reqs = get_all_args(DataHandlerFirFilterSingleStation, DataHandlerFilterSingleStation, DataHandlerSingleStation, + remove="self") + assert sorted(obj.requirements()) == reqs + + +class TestDataHandlerClimateFirFilter: + + def test_requirements(self): + reqs = get_all_args(DataHandlerFilter, DefaultDataHandler) + obj = object.__new__(DataHandlerClimateFirFilter) + assert sorted(obj.own_args()) == reqs + reqs = get_all_args(DataHandlerFirFilterSingleStation, DataHandlerFilterSingleStation, 
DataHandlerSingleStation, + DataHandlerClimateFirFilterSingleStation, remove="self") + assert sorted(obj._requirements) == reqs + reqs = get_all_args(DataHandlerFirFilterSingleStation, DataHandlerFilterSingleStation, DataHandlerSingleStation, + DataHandlerClimateFirFilterSingleStation, DefaultDataHandler, DataHandlerFilter, + remove=["self", "id_class"]) + assert sorted(obj.requirements()) == reqs + + +class TestDataHandlerClimateFirFilterSingleStation: + + def test_requirements(self): + reqs = get_all_args(DataHandlerFirFilterSingleStation, DataHandlerFilterSingleStation, DataHandlerSingleStation, + DataHandlerClimateFirFilterSingleStation) + obj = object.__new__(DataHandlerClimateFirFilterSingleStation) + assert sorted(obj.own_args()) == reqs + assert sorted(obj._requirements) == [] + reqs = get_all_args(DataHandlerFirFilterSingleStation, DataHandlerFilterSingleStation, DataHandlerSingleStation, + DataHandlerClimateFirFilterSingleStation, remove="self") + assert sorted(obj.requirements()) == reqs diff --git a/test/test_data_handler/test_default_data_handler.py b/test/test_data_handler/test_default_data_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..1e0a5db3d82bf528bfeef321799841588e2d5678 --- /dev/null +++ b/test/test_data_handler/test_default_data_handler.py @@ -0,0 +1,23 @@ +import pytest +from mlair.data_handler.default_data_handler import DefaultDataHandler +from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation +from mlair.helpers.testing import get_all_args + + +class TestDefaultDataHandler: + + def test_requirements(self): + reqs = get_all_args(DefaultDataHandler) + obj = object.__new__(DefaultDataHandler) + assert sorted(obj.own_args()) == reqs + reqs = get_all_args(DataHandlerSingleStation, remove="self") + assert sorted(obj._requirements) == reqs + reqs = get_all_args(DefaultDataHandler, DataHandlerSingleStation, remove=["self", "id_class"]) + assert sorted(obj.requirements()) == reqs + reqs = get_all_args(DefaultDataHandler, DataHandlerSingleStation, remove=["self", "id_class", "station"]) + assert sorted(obj.requirements(skip_args="station")) == reqs + + + + + diff --git a/test/test_data_handler/test_default_data_handler_single_station.py b/test/test_data_handler/test_default_data_handler_single_station.py new file mode 100644 index 0000000000000000000000000000000000000000..fea8f9cbddea4cdac350bc9df2c60c8e3a2e7399 --- /dev/null +++ b/test/test_data_handler/test_default_data_handler_single_station.py @@ -0,0 +1,15 @@ +import pytest +from mlair.data_handler.default_data_handler import DefaultDataHandler +from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation +from mlair.helpers.testing import get_all_args +from mlair.helpers import remove_items + + +class TestDataHandlerSingleStation: + + def test_requirements(self): + reqs = get_all_args(DataHandlerSingleStation) + obj = object.__new__(DataHandlerSingleStation) + assert sorted(obj.own_args()) == reqs + assert obj._requirements == [] + assert sorted(obj.requirements()) == remove_items(reqs, "self") diff --git a/test/test_helpers/test_filter.py b/test/test_helpers/test_filter.py new file mode 100644 index 0000000000000000000000000000000000000000..519a36b3438cafd041cc65c43572fc026eced4dd --- /dev/null +++ b/test/test_helpers/test_filter.py @@ -0,0 +1,403 @@ +__author__ = 'Lukas Leufen' +__date__ = '2021-11-18' + +import pytest +import inspect +import numpy as np +import xarray as xr +import pandas as pd + +from mlair.helpers.filter import 
ClimateFIRFilter, filter_width_kzf, firwin_kzf, omega_null_kzf, fir_filter_convolve + + +class TestClimateFIRFilter: + + @pytest.fixture + def var_dim(self): + return "variables" + + @pytest.fixture + def time_dim(self): + return "datetime" + + @pytest.fixture + def data(self): + pos = np.linspace(0, 4, num=100) + return np.cos(pos * np.pi) + + @pytest.fixture + def xr_array(self, data, time_dim): + start = np.datetime64("2010-01-01 00:00") + time_index = [start + np.timedelta64(h, "h") for h in range(len(data))] + array = xr.DataArray(data.reshape(len(data), 1), dims=[time_dim, "station"], + coords={time_dim: time_index, "station": ["DE266X"]}) + return array + + @pytest.fixture + def xr_array_long(self, data, time_dim): + start = np.datetime64("2010-01-01 00:00") + time_index = [start + np.timedelta64(175 * h, "h") for h in range(len(data))] + array = xr.DataArray(data.reshape(len(data), 1), dims=[time_dim, "station"], + coords={time_dim: time_index, "station": ["DE266X"]}) + return array + + @pytest.fixture + def xr_array_long_with_var(self, data, time_dim, var_dim): + start = np.datetime64("2010-01-01 00:00") + time_index = [start + np.timedelta64(175 * h, "h") for h in range(len(data))] + array = xr.DataArray(data.reshape(*data.shape, 1), dims=[time_dim, "station"], + coords={time_dim: time_index, "station": ["DE266X"]}) + array = array.resample({time_dim: "1H"}).interpolate() + new_data = xr.concat([array, + array + np.sin(np.arange(array.shape[0]) * 2 * np.pi / 24).reshape(*array.shape), + array + np.random.random(size=array.shape), + array * np.random.random(size=array.shape)], + dim=pd.Index(["o3", "temp", "wind", "sun"], name=var_dim)) + return new_data + + def test_combine_observation_and_apriori_no_new_dim(self, xr_array, time_dim): + obj = object.__new__(ClimateFIRFilter) + apriori = xr.ones_like(xr_array) + res = obj.combine_observation_and_apriori(xr_array, apriori, time_dim, "window", 20, 10) + assert res.coords[time_dim].values[0] == xr_array.coords[time_dim].values[20] + first_entry = res.sel({time_dim: res.coords[time_dim].values[0]}) + assert np.testing.assert_array_equal(first_entry.sel(window=range(-20, 1)).values, xr_array.values[:21]) is None + assert np.testing.assert_array_equal(first_entry.sel(window=range(1, 10)).values, apriori.values[21:30]) is None + + def test_combine_observation_and_apriori_with_new_dim(self, xr_array, time_dim): + obj = object.__new__(ClimateFIRFilter) + apriori = xr.ones_like(xr_array) + xr_array = obj._shift_data(xr_array, range(-20, 1), time_dim, new_dim="window") + apriori = obj._shift_data(apriori, range(1, 10), time_dim, new_dim="window") + res = obj.combine_observation_and_apriori(xr_array, apriori, time_dim, "window", 10, 10) + assert res.coords[time_dim].values[0] == xr_array.coords[time_dim].values[10] + date_pos = res.coords[time_dim].values[0] + first_entry = res.sel({time_dim: date_pos}) + assert xr.testing.assert_equal(first_entry.sel(window=range(-10, 1)), + xr_array.sel({time_dim: date_pos, "window": range(-10, 1)})) is None + assert xr.testing.assert_equal(first_entry.sel(window=range(1, 10)), apriori.sel({time_dim: date_pos})) is None + + def test_shift_data(self, xr_array, time_dim): + remaining_dims = set(xr_array.dims).difference([time_dim]) + obj = object.__new__(ClimateFIRFilter) + index_values = range(-15, 1) + res = obj._shift_data(xr_array, index_values, time_dim, new_dim="window") + assert len(res.dims) == len(remaining_dims) + 2 + assert len(set(res.dims).difference([time_dim, "window", *remaining_dims])) == 0 + 
assert np.testing.assert_array_equal(res.coords["window"].values, np.arange(-15, 1)) is None + sel = res.sel({time_dim: res.coords[time_dim].values[15]}) + assert sel.sel(window=-15).values == xr_array.sel({time_dim: xr_array.coords[time_dim].values[0]}).values + assert sel.sel(window=0).values == xr_array.sel({time_dim: xr_array.coords[time_dim].values[15]}).values + + def test_create_index_array(self): + obj = object.__new__(ClimateFIRFilter) + index_name = "test_index_name" + index_values = range(-10, 1) + res = obj.create_index_array(index_name, index_values) + assert len(res.dims) == 1 + assert res.dims[0] == index_name + assert res.shape == (11,) + assert np.testing.assert_array_equal(res.values, np.arange(-10, 1)) is None + + def test_create_tmp_dimension(self, xr_array, time_dim): + obj = object.__new__(ClimateFIRFilter) + res = obj._create_tmp_dimension(xr_array) + assert res == "window" + xr_array = xr_array.rename({time_dim: "window"}) + res = obj._create_tmp_dimension(xr_array) + assert res == "windowwindow" + xr_array = xr_array.rename({"window": "windowwindow"}) + res = obj._create_tmp_dimension(xr_array) + assert res == "window" + + def test_create_tmp_dimension_iter_limit(self, xr_array, time_dim): + obj = object.__new__(ClimateFIRFilter) + dim_name = "window" + xr_array = xr_array.rename({time_dim: "window"}) + for i in range(11): + dim_name += dim_name + xr_array = xr_array.expand_dims(dim_name, -1) + with pytest.raises(ValueError) as e: + obj._create_tmp_dimension(xr_array) + assert "Could not create new dimension." in e.value.args[0] + + def test_minimum_length(self): + obj = object.__new__(ClimateFIRFilter) + res = obj._minimum_length([43], 15, 0, "hamming") + assert res == 15 + res = obj._minimum_length([43, 13], 15, 0, ("kaiser", 10)) + assert res == 28 + res = obj._minimum_length([43, 13], 15, 1, "hamming") + assert res == 15 + res = obj._minimum_length([128, 64, 43], None, 0, "hamming") + assert res == 64 + res = obj._minimum_length([43], None, 0, "hamming") + assert res is None + + def test_minimum_length_with_kzf(self): + obj = object.__new__(ClimateFIRFilter) + res = obj._minimum_length([(15, 5), (5, 3)], None, 0, "kzf") + assert res == 13 + + def test_calculate_filter_coefficients(self): + obj = object.__new__(ClimateFIRFilter) + res = obj._calculate_filter_coefficients("hamming", 20, 1, 24) + assert res.shape == (20,) + assert np.testing.assert_almost_equal(res.sum(), 1) is None + res = obj._calculate_filter_coefficients(("kaiser", 10), 20, 1, 24) + assert res.shape == (20,) + assert np.testing.assert_almost_equal(res.sum(), 1) is None + res = obj._calculate_filter_coefficients("kzf", (5, 5), 1, 24) + assert res.shape == (21,) + assert np.testing.assert_almost_equal(res.sum(), 1) is None + + def test_create_monthly_mean(self, xr_array_long, time_dim): + obj = object.__new__(ClimateFIRFilter) + res = obj.create_monthly_mean(xr_array_long, time_dim) + assert res.shape == (1462, 1) + assert np.datetime64("2008-12-16") == res.coords[time_dim][0].values + assert np.datetime64("2012-12-16") == res.coords[time_dim][-1].values + mean_jan = xr_array_long[xr_array_long[f"{time_dim}.month"] == 1].mean() + assert res.sel({time_dim: "2009-01-16"}) == mean_jan + mean_jul = xr_array_long[xr_array_long[f"{time_dim}.month"] == 7].mean() + assert res.sel({time_dim: "2009-07-16"}) == mean_jul + assert res.sel({time_dim: "2010-06-15"}) < res.sel({time_dim: "2010-06-16"}) + assert res.sel({time_dim: "2010-06-17"}) > res.sel({time_dim: "2010-06-16"}) + + def 
test_create_monthly_mean_sampling(self, xr_array_long, time_dim): + obj = object.__new__(ClimateFIRFilter) + res = obj.create_monthly_mean(xr_array_long, time_dim, sampling="1m") + assert res.shape == (49, 1) + res = obj.create_monthly_mean(xr_array_long, time_dim, sampling="1H") + assert res.shape == (35065, 1) + mean_jun = xr_array_long[xr_array_long[f"{time_dim}.month"] == 6].mean() + assert res.sel({time_dim: "2010-06-15T00:00:00"}) == mean_jun + assert res.sel({time_dim: "2011-06-15T00:00:00"}) == mean_jun + + def test_create_monthly_mean_sel_opts(self, xr_array_long, time_dim): + obj = object.__new__(ClimateFIRFilter) + sel_opts = {time_dim: slice("2010-05", "2010-08")} + res = obj.create_monthly_mean(xr_array_long, time_dim, sel_opts=sel_opts) + assert res.dropna(time_dim)[f"{time_dim}.month"].min() == 5 + assert res.dropna(time_dim)[f"{time_dim}.month"].max() == 8 + mean_jun_2010 = xr_array_long[xr_array_long[f"{time_dim}.month"] == 6].sel({time_dim: "2010"}).mean() + assert res.sel({time_dim: "2010-06-15T00:00:00"}) == mean_jun_2010 + + def test_compute_hourly_mean_per_month(self, xr_array_long, time_dim): + obj = object.__new__(ClimateFIRFilter) + xr_array_long = xr_array_long.resample({time_dim: "1H"}).interpolate() + res = obj._compute_hourly_mean_per_month(xr_array_long, time_dim, True) + assert len(res.keys()) == 12 + assert 6 in res.keys() + assert np.testing.assert_almost_equal(res[12].mean(), 0) is None + assert np.testing.assert_almost_equal(res[3].mean(), 0) is None + assert res[8].shape == (24, 1) + + def test_compute_hourly_mean_per_month_no_anomaly(self, xr_array_long, time_dim): + obj = object.__new__(ClimateFIRFilter) + xr_array_long = xr_array_long.resample({time_dim: "1H"}).interpolate() + res = obj._compute_hourly_mean_per_month(xr_array_long, time_dim, False) + assert len(res.keys()) == 12 + assert 9 in res.keys() + assert np.testing.assert_array_less(res[2], res[1]) is None + + def test_create_seasonal_cycle_of_hourly_mean(self, xr_array_long, time_dim): + obj = object.__new__(ClimateFIRFilter) + xr_array_long = xr_array_long.resample({time_dim: "1H"}).interpolate() + monthly = obj.create_monthly_unity_array(xr_array_long, time_dim) * np.nan + seasonal_hourly_means = obj._compute_hourly_mean_per_month(xr_array_long, time_dim, True) + res = obj._create_seasonal_cycle_of_single_hour_mean(monthly, seasonal_hourly_means, 0, time_dim, "1h") + assert res[f"{time_dim}.hour"].sum() == 0 + assert np.testing.assert_almost_equal(res.sel({time_dim: "2010-12-01"}), res.sel({time_dim: "2011-12-01"})) is None + res = obj._create_seasonal_cycle_of_single_hour_mean(monthly, seasonal_hourly_means, 13, time_dim, "1h") + assert res[f"{time_dim}.hour"].mean() == 13 + assert np.testing.assert_almost_equal(res.sel({time_dim: "2010-12-01"}), res.sel({time_dim: "2011-12-01"})) is None + + def test_create_seasonal_hourly_mean(self, xr_array_long, time_dim): + obj = object.__new__(ClimateFIRFilter) + xr_array_long = xr_array_long.resample({time_dim: "1H"}).interpolate() + res = obj.create_seasonal_hourly_mean(xr_array_long, time_dim) + assert len(set(res.dims).difference(xr_array_long.dims)) == 0 + assert res.coords[time_dim][0] < xr_array_long.coords[time_dim][0] + assert res.coords[time_dim][-1] > xr_array_long.coords[time_dim][-1] + + def test_create_seasonal_hourly_mean_sel_opts(self, xr_array_long, time_dim): + obj = object.__new__(ClimateFIRFilter) + xr_array_long = xr_array_long.resample({time_dim: "1H"}).interpolate() + sel_opts = {time_dim: slice("2010-05", "2010-08")} + res = 
obj.create_seasonal_hourly_mean(xr_array_long, time_dim, sel_opts=sel_opts) + assert res.dropna(time_dim)[f"{time_dim}.month"].min() == 5 + assert res.dropna(time_dim)[f"{time_dim}.month"].max() == 8 + + def test_create_unity_array(self, xr_array, time_dim): + obj = object.__new__(ClimateFIRFilter) + res = obj.create_monthly_unity_array(xr_array, time_dim) + assert np.datetime64("2008-12-16") == res.coords[time_dim][0].values + assert np.datetime64("2011-01-16") == res.coords[time_dim][-1].values + assert res.max() == res.min() + assert res.max() == 1 + assert res.shape == (26, 1) + res = obj.create_monthly_unity_array(xr_array, time_dim, extend_range=0) + assert res.shape == (1, 1) + assert np.datetime64("2010-01-16") == res.coords[time_dim][0].values + res = obj.create_monthly_unity_array(xr_array, time_dim, extend_range=28) + assert res.shape == (3, 1) + + def test_extend_apriori_at_end(self, xr_array_long, time_dim): + obj = object.__new__(ClimateFIRFilter) + apriori = xr.ones_like(xr_array_long).sel({time_dim: "2010"}) + res = obj.extend_apriori(xr_array_long, apriori, time_dim) + assert res.coords[time_dim][0] == apriori.coords[time_dim][0] + assert (res.coords[time_dim][-1] - xr_array_long.coords[time_dim][-1]) / np.timedelta64(1, "D") >= 365 + apriori = xr.ones_like(xr_array_long).sel({time_dim: slice("2010", "2011-08")}) + res = obj.extend_apriori(xr_array_long, apriori, time_dim) + assert (res.coords[time_dim][-1] - xr_array_long.coords[time_dim][-1]) / np.timedelta64(1, "D") >= (1.5 * 365) + + def test_extend_apriori_at_start(self, xr_array_long, time_dim): + obj = object.__new__(ClimateFIRFilter) + apriori = xr.ones_like(xr_array_long).sel({time_dim: "2011"}) + res = obj.extend_apriori(xr_array_long.sel({time_dim: slice("2010", "2010-10")}), apriori, time_dim) + assert (res.coords[time_dim][0] - apriori.coords[time_dim][0]) / np.timedelta64(1, "D") <= -365 * 2 + assert res.coords[time_dim][-1] == apriori.coords[time_dim][-1] + apriori = xr.ones_like(xr_array_long).sel({time_dim: slice("2010-02", "2011")}) + res = obj.extend_apriori(xr_array_long, apriori, time_dim) + assert (res.coords[time_dim][0] - apriori.coords[time_dim][0]) / np.timedelta64(1, "D") <= -365 + + def test_get_year_interval(self, xr_array, xr_array_long, time_dim): + obj = object.__new__(ClimateFIRFilter) + assert obj._get_year_interval(xr_array, time_dim) == (2010, 2010) + assert obj._get_year_interval(xr_array_long, time_dim) == (2010, 2011) + + def test_create_time_range_extend(self): + obj = object.__new__(ClimateFIRFilter) + res = obj._create_time_range_extend(1992, "1d", 10) + assert isinstance(res, slice) + assert res.start == np.datetime64("1991-12-21") + assert res.stop == np.datetime64("1993-01-11") + assert res.step is None + res = obj._create_time_range_extend(1992, "1H", 24) + assert isinstance(res, slice) + assert res.start == np.datetime64("1991-12-30T23:00:00") + assert res.stop == np.datetime64("1993-01-01T01:00:00") + assert res.step is None + + def test_properties(self): + obj = object.__new__(ClimateFIRFilter) + obj._h = [1, 2, 3] + obj._filtered = [4, 5, 63] + obj._apriori_list = [10, 11, 12, 13] + assert obj.filter_coefficients == [1, 2, 3] + assert obj.filtered_data == [4, 5, 63] + assert obj.apriori_data == [10, 11, 12, 13] + assert obj.initial_apriori_data == 10 + + def test_trim_data_to_minimum_length(self, xr_array, time_dim): + obj = object.__new__(ClimateFIRFilter) + xr_array = obj._shift_data(xr_array, range(-20, 1), time_dim, new_dim="window") + res = 
obj._trim_data_to_minimum_length(xr_array, 5, "window") + assert xr_array.shape == (21, 100, 1) + assert res.shape == (6, 100, 1) + res = obj._trim_data_to_minimum_length(xr_array, 5, "window", 10) + assert res.shape == (11, 100, 1) + res = obj._trim_data_to_minimum_length(xr_array, 30, "window") + assert res.shape == (21, 100, 1) + + def test_create_full_filter_result_array(self, xr_array, time_dim): + obj = object.__new__(ClimateFIRFilter) + xr_array_window = obj._shift_data(xr_array, range(-10, 1), time_dim, new_dim="window").dropna(time_dim) + res = obj._create_full_filter_result_array(xr_array, xr_array_window, "window") + assert res.dims == (*xr_array.dims, "window") + assert res.shape == (*xr_array.shape, 11) + res2 = obj._create_full_filter_result_array(res, xr_array_window, "window") + assert res.dims == res2.dims + assert res.shape == res2.shape + + def test_clim_filter(self, xr_array_long_with_var, time_dim, var_dim): + obj = object.__new__(ClimateFIRFilter) + filter_order = 10*24+1 + res = obj.clim_filter(xr_array_long_with_var, 24, 0.05, 10*24+1, sampling="1H", time_dim=time_dim, var_dim=var_dim) + assert len(res) == 4 + + # check filter data properties + assert res[0].shape == (*xr_array_long_with_var.shape, filter_order + 1) + assert res[0].dims == (*xr_array_long_with_var.dims, "window") + + # check filter properties + assert np.testing.assert_almost_equal( + res[1], obj._calculate_filter_coefficients("hamming", filter_order, 0.05, 24)) is None + + # check apriori + apriori = obj.create_monthly_mean(xr_array_long_with_var, time_dim, sampling="1H") + apriori = apriori.astype(xr_array_long_with_var.dtype) + apriori = obj.extend_apriori(xr_array_long_with_var, apriori, time_dim, "1H") + assert xr.testing.assert_equal(res[2], apriori) is None + + # check plot data format + assert isinstance(res[3], list) + assert isinstance(res[3][0], dict) + keys = {"t0", "var", "filter_input", "filter_input_nc", "valid_range", "time_range", "h", "new_dim"} + assert len(keys.symmetric_difference(res[3][0].keys())) == 0 + + def test_clim_filter_kwargs(self, xr_array_long_with_var, time_dim, var_dim): + obj = object.__new__(ClimateFIRFilter) + filter_order = 10 * 24 + 1 + apriori = obj.create_seasonal_hourly_mean(xr_array_long_with_var, time_dim, sampling="1H", as_anomaly=False) + apriori = apriori.astype(xr_array_long_with_var.dtype) + apriori = obj.extend_apriori(xr_array_long_with_var, apriori, time_dim, "1H") + plot_dates = [xr_array_long_with_var.coords[time_dim][1800].values] + res = obj.clim_filter(xr_array_long_with_var, 24, 0.05, 10 * 24 + 1, sampling="1H", time_dim=time_dim, + var_dim=var_dim, new_dim="total_new_dim", window=("kaiser", 5), minimum_length=1000, + apriori=apriori, plot_dates=plot_dates) + + assert res[0].shape == (*xr_array_long_with_var.shape, 1000 + 1) + assert res[0].dims == (*xr_array_long_with_var.dims, "total_new_dim") + assert np.testing.assert_almost_equal( + res[1], obj._calculate_filter_coefficients(("kaiser", 5), filter_order, 0.05, 24)) is None + assert xr.testing.assert_equal(res[2], apriori) is None + assert len(res[3]) == len(res[0].coords[var_dim]) + + +class TestFirFilterConvolve: + + def test_fir_filter_convolve(self): + data = np.cos(np.linspace(0, 4, num=100) * np.pi) + obj = object.__new__(ClimateFIRFilter) + h = obj._calculate_filter_coefficients("hamming", 21, 0.25, 1) + res = fir_filter_convolve(data, h) + assert res.shape == (100,) + assert np.testing.assert_almost_equal(np.dot(data[40:61], h) / sum(h), res[50]) is None + + +class TestFirwinKzf: 
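+    # firwin_kzf(m, k) presumably returns Kolmogorov-Zurbenko filter coefficients, i.e. the
+    # k-fold convolution of a length-m moving-average window, normalized to sum to 1; for
+    # m=3, k=3 the unnormalized weights are the trinomial coefficients [1, 3, 6, 7, 6, 3, 1]
+    # (sum 3**3 = 27), which is what the assertions below check.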
+ + def test_firwin_kzf(self): + res = firwin_kzf(3, 3) + assert np.testing.assert_almost_equal(res.sum(), 1) is None + assert res.shape == (7,) + assert np.testing.assert_array_equal(res * (3**3), np.array([1, 3, 6, 7, 6, 3, 1])) is None + + +class TestFilterWidthKzf: + + def test_filter_width_kzf(self): + assert filter_width_kzf(15, 5) == 71 + assert filter_width_kzf(3, 5) == 11 + + +class TestOmegaNullKzf: + + def test_omega_null_kzf(self): + assert np.testing.assert_almost_equal(omega_null_kzf(13, 3), 0.01986, decimal=5) is None + assert np.testing.assert_almost_equal(omega_null_kzf(105, 5), 0.00192, decimal=5) is None + assert np.testing.assert_almost_equal(omega_null_kzf(3, 5), 0.07103, decimal=5) is None + + def test_omega_null_kzf_alpha(self): + assert np.testing.assert_almost_equal(omega_null_kzf(3, 3, alpha=1), 0, decimal=1) is None + assert np.testing.assert_almost_equal(omega_null_kzf(3, 3, alpha=0), 0.25989, decimal=5) is None + assert np.testing.assert_almost_equal(omega_null_kzf(3, 3), omega_null_kzf(3, 3, alpha=0.5), decimal=5) is None + + + + + + diff --git a/test/test_helpers/test_testing_helpers.py b/test/test_helpers/test_testing_helpers.py index 385161c740f386847ef2f2dc4df17c1c84fa7fa5..bceed646c345d3add4602e67b55da1553eabdbaa 100644 --- a/test/test_helpers/test_testing_helpers.py +++ b/test/test_helpers/test_testing_helpers.py @@ -1,4 +1,4 @@ -from mlair.helpers.testing import PyTestRegex, PyTestAllEqual +from mlair.helpers.testing import PyTestRegex, PyTestAllEqual, test_nested_equality import re import xarray as xr @@ -11,7 +11,8 @@ class TestPyTestRegex: def test_init(self): test_regex = PyTestRegex(r"TestString\d+") - assert isinstance(test_regex._regex, re._pattern_type) + pattern = re._pattern_type if hasattr(re, "_pattern_type") else re.Pattern + assert isinstance(test_regex._regex, pattern) def test_eq(self): assert PyTestRegex(r"TestString\d*") == "TestString" @@ -46,3 +47,35 @@ class TestPyTestAllEqual: [xr.DataArray([1, 2, 3]), xr.DataArray([12, 22, 32])]]) assert PyTestAllEqual([["test", "test2"], ["test", "test2"]]) + + +class TestNestedEquality: + + def test_nested_equality_single_entries(self): + assert test_nested_equality(3, 3) is True + assert test_nested_equality(3.9, 3.9) is True + assert test_nested_equality(3.91, 3.9) is False + assert test_nested_equality("3", 3) is False + assert test_nested_equality("3", "3") is True + assert test_nested_equality(None, None) is True + + def test_nested_equality_xarray(self): + obj1 = xr.DataArray(np.random.randn(2, 3), dims=('x', 'y'), coords={'x': [10, 20], 'y': [0, 10, 20]}) + obj2 = xr.ones_like(obj1) * obj1 + assert test_nested_equality(obj1, obj2) is True + + def test_nested_equality_numpy(self): + obj1 = np.random.randn(2, 3) + obj2 = obj1 * 1 + assert test_nested_equality(obj1, obj2) is True + + def test_nested_equality_list_tuple(self): + assert test_nested_equality([3, 3], [3, 3]) is True + assert test_nested_equality((2, 6), (2, 6)) is True + assert test_nested_equality([3, 3.5], [3.5, 3]) is False + assert test_nested_equality([3, 3.5, 10], [3, 3.5]) is False + + def test_nested_equality_dict(self): + assert test_nested_equality({"a": 3, "b": 10}, {"b": 10, "a": 3}) is True + assert test_nested_equality({"a": 3, "b": [10, 100]}, {"b": [10, 100], "a": 3}) is True + assert test_nested_equality({"a": 3, "b": 10, "c": "c"}, {"b": 10, "a": 3}) is False diff --git a/test/test_run_modules/test_training.py b/test/test_run_modules/test_training.py index 
9d633a348bd1e24cd3f3abcdb83124f6107db2e9..1b83b3823519d63d5dcbc10f0e31fc3433f98f34 100644
--- a/test/test_run_modules/test_training.py
+++ b/test/test_run_modules/test_training.py
@@ -1,8 +1,12 @@
+import copy
 import glob
 import json
+import time
+
 import logging
 import os
 import shutil
+from typing import Callable
 
 import tensorflow.keras as keras
 import mock
@@ -11,6 +15,7 @@ from tensorflow.keras.callbacks import History
 
 from mlair.data_handler import DataCollection, KerasIterator, DefaultDataHandler
 from mlair.helpers import PyTestRegex
+from mlair.model_modules.fully_connected_networks import FCN
 from mlair.model_modules.flatten import flatten_tail
 from mlair.model_modules.inception_model import InceptionModelBase
 from mlair.model_modules.keras_extensions import LearningRateDecay, HistoryAdvanced, CallbackHandler, EpoTimingCallback
@@ -76,10 +81,15 @@ class TestTraining:
         obj.data_store.set("plot_path", path_plot, "general")
         obj._train_model = True
         obj._create_new_model = False
-        yield obj
-        if os.path.exists(path):
-            shutil.rmtree(path)
-        RunEnvironment().__del__()
+        try:
+            yield obj
+        finally:
+            if os.path.exists(path):
+                shutil.rmtree(path)
+            try:
+                RunEnvironment().__del__()
+            except AssertionError:
+                pass
 
     @pytest.fixture
     def learning_rate(self):
@@ -144,7 +163,7 @@ class TestTraining:
     @pytest.fixture
     def model(self, window_history_size, window_lead_time, statistics_per_var):
         channels = len(list(statistics_per_var.keys()))
-        return my_test_model(keras.layers.PReLU, window_history_size, channels, window_lead_time, 0.1, False)
+        return FCN([(window_history_size + 1, 1, channels)], [window_lead_time])
 
     @pytest.fixture
     def callbacks(self, path):
@@ -174,7 +193,8 @@ class TestTraining:
         obj.data_store.set("data_collection", data_collection, "general.train")
         obj.data_store.set("data_collection", data_collection, "general.val")
         obj.data_store.set("data_collection", data_collection, "general.test")
-        obj.model.compile(optimizer=keras.optimizers.SGD(), loss=keras.losses.mean_absolute_error)
+        obj.model.compile(**obj.model.compile_options)
+        keras.utils.get_custom_objects().update(obj.model.custom_objects)
         return obj
 
     @pytest.fixture
@@ -209,6 +229,57 @@ class TestTraining:
         if os.path.exists(path):
             shutil.rmtree(path)
 
+    @staticmethod
+    def create_training_obj(epochs, path, data_collection, batch_path, model_path,
+                            statistics_per_var, window_history_size, window_lead_time) -> Training:
+
+        channels = len(list(statistics_per_var.keys()))
+        model = FCN([(window_history_size + 1, 1, channels)], [window_lead_time])
+
+        obj = object.__new__(Training)
+        super(Training, obj).__init__()
+        obj.model = model
+        obj.train_set = None
+        obj.val_set = None
+        obj.test_set = None
+        obj.batch_size = 256
+        obj.epochs = epochs
+
+        clbk = CallbackHandler()
+        hist = HistoryAdvanced()
+        epo_timing = EpoTimingCallback()
+        clbk.add_callback(hist, os.path.join(path, "hist_checkpoint.pickle"), "hist")
+        lr = LearningRateDecay()
+        clbk.add_callback(lr, os.path.join(path, "lr_checkpoint.pickle"), "lr")
+        clbk.add_callback(epo_timing, os.path.join(path, "epo_timing.pickle"), "epo_timing")
+        clbk.create_model_checkpoint(filepath=os.path.join(path, "model_checkpoint"), monitor='val_loss',
+                                     save_best_only=True)
+        obj.callbacks = clbk
+        obj.lr_sc = lr
+        obj.hist = hist
+        obj.experiment_name = "TestExperiment"
+        obj.data_store.set("data_collection", data_collection, 
"general.train") + obj.data_store.set("data_collection", data_collection, "general.val") + obj.data_store.set("data_collection", data_collection, "general.test") + if not os.path.exists(path): + os.makedirs(path) + obj.data_store.set("experiment_path", path, "general") + os.makedirs(batch_path, exist_ok=True) + obj.data_store.set("batch_path", batch_path, "general") + os.makedirs(model_path, exist_ok=True) + obj.data_store.set("model_path", model_path, "general") + obj.data_store.set("model_name", os.path.join(model_path, "test_model.h5"), "general.model") + obj.data_store.set("experiment_name", "TestExperiment", "general") + + path_plot = os.path.join(path, "plots") + os.makedirs(path_plot, exist_ok=True) + obj.data_store.set("plot_path", path_plot, "general") + obj._train_model = True + obj._create_new_model = False + + obj.model.compile(**obj.model.compile_options) + return obj + def test_init(self, ready_to_init): assert isinstance(Training(), Training) # just test, if nothing fails @@ -223,9 +294,10 @@ class TestTraining: assert ready_to_run._run() is None # just test, if nothing fails def test_make_predict_function(self, init_without_run): - assert hasattr(init_without_run.model, "predict_function") is False + assert hasattr(init_without_run.model, "predict_function") is True + assert init_without_run.model.predict_function is None init_without_run.make_predict_function() - assert hasattr(init_without_run.model, "predict_function") + assert isinstance(init_without_run.model.predict_function, Callable) def test_set_gen(self, init_without_run): assert init_without_run.train_set is None @@ -234,7 +306,7 @@ class TestTraining: assert init_without_run.train_set._collection.return_value == "mock_train_gen" def test_set_generators(self, init_without_run): - sets = ["train", "val", "test"] + sets = ["train", "val"] assert all([getattr(init_without_run, f"{obj}_set") is None for obj in sets]) init_without_run.set_generators() assert not all([getattr(init_without_run, f"{obj}_set") is None for obj in sets]) @@ -242,10 +314,10 @@ class TestTraining: [getattr(init_without_run, f"{obj}_set")._collection.return_value == f"mock_{obj}_gen" for obj in sets]) def test_train(self, ready_to_train, path): - assert not hasattr(ready_to_train.model, "history") + assert ready_to_train.model.history is None assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 0 ready_to_train.train() - assert list(ready_to_train.model.history.history.keys()) == ["val_loss", "loss"] + assert sorted(list(ready_to_train.model.history.history.keys())) == ["loss", "val_loss"] assert ready_to_train.model.history.epoch == [0, 1] assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 2 @@ -260,8 +332,8 @@ class TestTraining: def test_load_best_model_no_weights(self, init_without_run, caplog): caplog.set_level(logging.DEBUG) - init_without_run.load_best_model("notExisting") - assert caplog.record_tuples[0] == ("root", 10, PyTestRegex("load best model: notExisting")) + init_without_run.load_best_model("notExisting.h5") + assert caplog.record_tuples[0] == ("root", 10, PyTestRegex("load best model: notExisting.h5")) assert caplog.record_tuples[1] == ("root", 20, PyTestRegex("no weights to reload...")) def test_save_callbacks_history_created(self, init_without_run, history, learning_rate, epo_timing, model_path): @@ -290,3 +362,14 @@ class TestTraining: history.model.metrics_names = mock.MagicMock(return_value=["loss", "mean_squared_error"]) 
init_without_run.create_monitoring_plots(history, learning_rate) assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 2 + + def test_resume_training1(self, path: str, model_path, batch_path, data_collection, statistics_per_var, + window_history_size, window_lead_time): + + obj_1st = self.create_training_obj(4, path, data_collection, batch_path, model_path, statistics_per_var, + window_history_size, window_lead_time) + keras.utils.get_custom_objects().update(obj_1st.model.custom_objects) + assert obj_1st._run() is None + obj_2nd = self.create_training_obj(8, path, data_collection, batch_path, model_path, statistics_per_var, + window_history_size, window_lead_time) + assert obj_2nd._run() is None
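+        # Presumably the second run resumes from the model checkpoint and pickled callback
+        # states written by the first run (both runs use the same experiment and model paths),
+        # so training continues towards 8 epochs instead of restarting from scratch.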