diff --git a/docs/_source/_plots/separation_of_scales.png b/docs/_source/_plots/separation_of_scales.png new file mode 100755 index 0000000000000000000000000000000000000000..d2bbc625a5d50051d8ec2babe976f88d7446e39e Binary files /dev/null and b/docs/_source/_plots/separation_of_scales.png differ diff --git a/mlair/configuration/defaults.py b/mlair/configuration/defaults.py index e30a6ec122f728dfb24be7252c94c73dc7347e58..ce42fc0eed6e891bc0a0625666da3dccfcc8a3ee 100644 --- a/mlair/configuration/defaults.py +++ b/mlair/configuration/defaults.py @@ -48,7 +48,7 @@ DEFAULT_CREATE_NEW_BOOTSTRAPS = False DEFAULT_NUMBER_OF_BOOTSTRAPS = 20 DEFAULT_PLOT_LIST = ["PlotMonthlySummary", "PlotStationMap", "PlotClimatologicalSkillScore", "PlotTimeSeries", "PlotCompetitiveSkillScore", "PlotBootstrapSkillScore", "PlotConditionalQuantiles", - "PlotAvailability"] + "PlotAvailability", "PlotSeparationOfScales"] DEFAULT_SAMPLING = "daily" DEFAULT_DATA_ORIGIN = {"cloudcover": "REA", "humidity": "REA", "pblheight": "REA", "press": "REA", "relhum": "REA", "temp": "REA", "totprecip": "REA", "u": "REA", "v": "REA", "no": "", "no2": "", "o3": "", diff --git a/mlair/data_handler/data_handler_kz_filter.py b/mlair/data_handler/data_handler_kz_filter.py index 965474c79dfc57456742fb5fa283f43eef081296..adc5ee0e72694baed6ec0ab0c0bf9259126af292 100644 --- a/mlair/data_handler/data_handler_kz_filter.py +++ b/mlair/data_handler/data_handler_kz_filter.py @@ -25,11 +25,9 @@ class DataHandlerKzFilterSingleStation(DataHandlerSingleStation): def __init__(self, *args, kz_filter_length, kz_filter_iter, **kwargs): self._check_sampling(**kwargs) - kz_filter_length = to_list(kz_filter_length) - kz_filter_iter = to_list(kz_filter_iter) # self.original_data = None # ToDo: implement here something to store unfiltered data - self.kz_filter_length = kz_filter_length - self.kz_filter_iter = kz_filter_iter + self.kz_filter_length = to_list(kz_filter_length) + self.kz_filter_iter = to_list(kz_filter_iter) 
self.cutoff_period = None self.cutoff_period_days = None super().__init__(*args, **kwargs) diff --git a/mlair/data_handler/data_handler_mixed_sampling.py b/mlair/data_handler/data_handler_mixed_sampling.py index 65e19133d67722ef82c44bb2c8912bb2ac3a7350..aa1f0d55b55757875b640de00f66e62dd3586b11 100644 --- a/mlair/data_handler/data_handler_mixed_sampling.py +++ b/mlair/data_handler/data_handler_mixed_sampling.py @@ -4,15 +4,15 @@ __date__ = '2020-11-05' from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation from mlair.data_handler.data_handler_kz_filter import DataHandlerKzFilterSingleStation from mlair.data_handler import DefaultDataHandler -from mlair.configuration import path_config from mlair import helpers from mlair.helpers import remove_items from mlair.configuration.defaults import DEFAULT_SAMPLING -import logging -import os import inspect +from typing import Callable +import datetime as dt +import numpy as np import pandas as pd import xarray as xr @@ -37,7 +37,8 @@ class DataHandlerMixedSamplingSingleStation(DataHandlerSingleStation): def load_and_interpolate(self, ind) -> [xr.DataArray, pd.DataFrame]: data, self.meta = self.load_data(self.path[ind], self.station, self.statistics_per_var, self.sampling[ind], - self.station_type, self.network, self.store_data_locally, self.data_origin) + self.station_type, self.network, self.store_data_locally, self.data_origin, + self.start, self.end) data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method, limit=self.interpolation_limit) return data @@ -88,6 +89,34 @@ class DataHandlerMixedSamplingWithFilterSingleStation(DataHandlerMixedSamplingSi self.call_transform() self.make_samples() + def estimate_filter_width(self): + """ + f = 0.5 / (len * sqrt(itr)) -> T = 1 / f + :return: + """ + return int(self.kz_filter_length[0] * np.sqrt(self.kz_filter_iter[0]) * 2) + + @staticmethod + def _add_time_delta(date, delta): + new_date = dt.datetime.strptime(date, "%Y-%m-%d") + 
dt.timedelta(hours=delta) + return new_date.strftime("%Y-%m-%d") + + def load_and_interpolate(self, ind) -> [xr.DataArray, pd.DataFrame]: + + if ind == 0: # for inputs + estimated_filter_width = self.estimate_filter_width() + start = self._add_time_delta(self.start, -estimated_filter_width) + end = self._add_time_delta(self.end, estimated_filter_width) + else: # target + start, end = self.start, self.end + + data, self.meta = self.load_data(self.path[ind], self.station, self.statistics_per_var, self.sampling[ind], + self.station_type, self.network, self.store_data_locally, self.data_origin, + start, end) + data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method, + limit=self.interpolation_limit) + return data + class DataHandlerMixedSamplingWithFilter(DefaultDataHandler): """Data handler using mixed sampling for input and target. Inputs are temporal filtered.""" @@ -95,3 +124,80 @@ class DataHandlerMixedSamplingWithFilter(DefaultDataHandler): data_handler = DataHandlerMixedSamplingWithFilterSingleStation data_handler_transformation = DataHandlerMixedSamplingWithFilterSingleStation _requirements = data_handler.requirements() + + +class DataHandlerMixedSamplingSeparationOfScalesSingleStation(DataHandlerMixedSamplingWithFilterSingleStation): + """ + Data handler using mixed sampling for input and target. Inputs are temporal filtered and depending on the + separation frequency of a filtered time series the time step delta for input data is adjusted (see image below). + + .. 
image:: ../../../../../_source/_plots/separation_of_scales.png + :width: 400 + + """ + + _requirements = DataHandlerMixedSamplingWithFilterSingleStation.requirements() + + def __init__(self, *args, time_delta=np.sqrt, **kwargs): + assert isinstance(time_delta, Callable) + self.time_delta = time_delta + super().__init__(*args, **kwargs) + + def make_history_window(self, dim_name_of_inputs: str, window: int, dim_name_of_shift: str) -> None: + """ + Create a xr.DataArray containing history data. + + Shift the data window+1 times and return a xarray which has a new dimension 'window' containing the shifted + data. This is used to represent history in the data. Results are stored in history attribute. + + :param dim_name_of_inputs: Name of dimension which contains the input variables + :param window: number of time steps to look back in history + Note: window will be treated as negative value. This should be in agreement with looking back on + a time line. Nonetheless positive values are allowed but they are converted to its negative + expression + :param dim_name_of_shift: Dimension along shift will be applied + """ + window = -abs(window) + data = self.input_data.data + self.history = self.stride(data, dim_name_of_shift, window) + + def stride(self, data: xr.DataArray, dim: str, window: int) -> xr.DataArray: + + # this is just a code snippet to check the results of the kz filter + # import matplotlib + # matplotlib.use("TkAgg") + # import matplotlib.pyplot as plt + # xr.concat(res, dim="filter").sel({"variables":"temp", "Stations":"DEBW107", "datetime":"2010-01-01T00:00:00"}).plot.line(hue="filter") + + time_deltas = np.round(self.time_delta(self.cutoff_period)).astype(int) + start, end = window, 1 + res = [] + window_array = self.create_index_array('window', range(start, end), squeeze_dim=self.target_dim) + for delta, filter_name in zip(np.append(time_deltas, 1), data.coords["filter"]): + res_filter = [] + data_filter = data.sel({"filter": filter_name}) + for w in 
range(start, end): + res_filter.append(data_filter.shift({dim: -w * delta})) + res_filter = xr.concat(res_filter, dim=window_array).chunk() + res.append(res_filter) + res = xr.concat(res, dim="filter") + return res + + def estimate_filter_width(self): + """ + Attention: this method returns the maximum value of + * either estimated filter width f = 0.5 / (len * sqrt(itr)) -> T = 1 / f or + * time delta method applied on the estimated filter width multiplied by window_history_size + to provide a sufficiently wide filter width. + """ + est = self.kz_filter_length[0] * np.sqrt(self.kz_filter_iter[0]) * 2 + return int(max([self.time_delta(est) * self.window_history_size, est])) + + +class DataHandlerMixedSamplingSeparationOfScales(DefaultDataHandler): + """Data handler using mixed sampling for input and target. Inputs are temporal filtered and different time step + sizes are applied in relation to frequencies.""" + + data_handler = DataHandlerMixedSamplingSeparationOfScalesSingleStation + data_handler_transformation = DataHandlerMixedSamplingSeparationOfScalesSingleStation + _requirements = data_handler.requirements() diff --git a/mlair/data_handler/data_handler_single_station.py b/mlair/data_handler/data_handler_single_station.py index caf0d56186f3b495f406e740e135c1f38b3b5046..e554a3b32d8e4e2f5482a388374cfba87f7add15 100644 --- a/mlair/data_handler/data_handler_single_station.py +++ b/mlair/data_handler/data_handler_single_station.py @@ -143,7 +143,8 @@ class DataHandlerSingleStation(AbstractDataHandler): Setup samples. This method prepares and creates samples X, and labels Y. 
""" data, self.meta = self.load_data(self.path, self.station, self.statistics_per_var, self.sampling, - self.station_type, self.network, self.store_data_locally, self.data_origin) + self.station_type, self.network, self.store_data_locally, self.data_origin, + self.start, self.end) self._data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method, limit=self.interpolation_limit) self.set_inputs_and_targets() @@ -164,7 +165,7 @@ class DataHandlerSingleStation(AbstractDataHandler): self.remove_nan(self.time_dim) def load_data(self, path, station, statistics_per_var, sampling, station_type=None, network=None, - store_data_locally=False, data_origin: Dict = None): + store_data_locally=False, data_origin: Dict = None, start=None, end=None): """ Load data and meta data either from local disk (preferred) or download new data by using a custom download method. @@ -200,7 +201,7 @@ class DataHandlerSingleStation(AbstractDataHandler): store_data_locally=store_data_locally, data_origin=data_origin) logging.debug("loading finished") # create slices and check for negative concentration. - data = self._slice_prep(data) + data = self._slice_prep(data, start=start, end=end) data = self.check_for_negative_concentrations(data) return data, meta @@ -443,7 +444,7 @@ class DataHandlerSingleStation(AbstractDataHandler): self.label = self.label.sel({dim: intersect}) self.observation = self.observation.sel({dim: intersect}) - def _slice_prep(self, data: xr.DataArray, coord: str = 'datetime') -> xr.DataArray: + def _slice_prep(self, data: xr.DataArray, start=None, end=None) -> xr.DataArray: """ Set start and end date for slicing and execute self._slice(). 
@@ -452,9 +453,9 @@ class DataHandlerSingleStation(AbstractDataHandler): :return: sliced data """ - start = self.start if self.start is not None else data.coords[coord][0].values - end = self.end if self.end is not None else data.coords[coord][-1].values - return self._slice(data, start, end, coord) + start = start if start is not None else data.coords[self.time_dim][0].values + end = end if end is not None else data.coords[self.time_dim][-1].values + return self._slice(data, start, end, self.time_dim) @staticmethod def _slice(data: xr.DataArray, start: Union[date, str], end: Union[date, str], coord: str) -> xr.DataArray: diff --git a/mlair/helpers/join.py b/mlair/helpers/join.py index fc86beb375e4c763ab0721f1e5dcc89eaa27a605..43a0176811b54fba2983c1dba108f4c7977f1431 100644 --- a/mlair/helpers/join.py +++ b/mlair/helpers/join.py @@ -49,7 +49,7 @@ def download_join(station_name: Union[str, List[str]], stat_var: dict, station_t # correct stat_var values if data is not aggregated (hourly) if sampling == "hourly": - [stat_var.update({k: "values"}) for k in stat_var.keys()] + stat_var = {key: "values" for key in stat_var.keys()} # download all variables with given statistic data = None diff --git a/mlair/plotting/postprocessing_plotting.py b/mlair/plotting/postprocessing_plotting.py index c8682374e0d4c0d724d83a5e36977543ac3a50f8..caefbd82193b2202edb9e15d13ab34eeb143a97c 100644 --- a/mlair/plotting/postprocessing_plotting.py +++ b/mlair/plotting/postprocessing_plotting.py @@ -25,6 +25,11 @@ from mlair.helpers import TimeTrackingWrapper logging.getLogger('matplotlib').setLevel(logging.WARNING) +# import matplotlib +# matplotlib.use("TkAgg") +# import matplotlib.pyplot as plt + + class AbstractPlotClass: """ Abstract class for all plotting routines to unify plot workflow. 
@@ -72,6 +77,9 @@ class AbstractPlotClass: def __init__(self, plot_folder, plot_name, resolution=500): """Set up plot folder and name, and plot resolution (default 500dpi).""" + plot_folder = os.path.abspath(plot_folder) + if not os.path.exists(plot_folder): + os.makedirs(plot_folder) self.plot_folder = plot_folder self.plot_name = plot_name self.resolution = resolution @@ -82,7 +90,7 @@ class AbstractPlotClass: def _save(self, **kwargs): """Store plot locally. Name of and path to plot need to be set on initialisation.""" - plot_name = os.path.join(os.path.abspath(self.plot_folder), f"{self.plot_name}.pdf") + plot_name = os.path.join(self.plot_folder, f"{self.plot_name}.pdf") logging.debug(f"... save plot to {plot_name}") plt.savefig(plot_name, dpi=self.resolution, **kwargs) plt.close('all') @@ -995,10 +1003,31 @@ class PlotAvailability(AbstractPlotClass): return lgd +@TimeTrackingWrapper +class PlotSeparationOfScales(AbstractPlotClass): + + def __init__(self, collection: DataCollection, plot_folder: str = "."): + """Initialise.""" + # create standard Gantt plot for all stations (currently in single pdf file with single page) + plot_folder = os.path.join(plot_folder, "separation_of_scales") + super().__init__(plot_folder, "separation_of_scales") + self._plot(collection) + + def _plot(self, collection: DataCollection): + orig_plot_name = self.plot_name + for dh in collection: + data = dh.get_X(as_numpy=False)[0] + station = dh.id_class.station[0] + data = data.sel(Stations=station) + # plt.subplots() + data.plot(x="datetime", y="window", col="filter", row="variables", robust=True) + self.plot_name = f"{orig_plot_name}_{station}" + self._save() + + if __name__ == "__main__": stations = ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'] path = "../../testrun_network/forecasts" plt_path = "../../" con_quan_cls = PlotConditionalQuantiles(stations, path, plt_path) - diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py index 
571d3a07d15873af1c1ccedc59e0cc462e07820f..3dc91cbd54094f116f0d959fb9c845751e998464 100644 --- a/mlair/run_modules/post_processing.py +++ b/mlair/run_modules/post_processing.py @@ -19,7 +19,8 @@ from mlair.helpers import TimeTracking, statistics, extract_value from mlair.model_modules.linear_model import OrdinaryLeastSquaredModel from mlair.model_modules.model_class import AbstractModelClass from mlair.plotting.postprocessing_plotting import PlotMonthlySummary, PlotStationMap, PlotClimatologicalSkillScore, \ - PlotCompetitiveSkillScore, PlotTimeSeries, PlotBootstrapSkillScore, PlotAvailability, PlotConditionalQuantiles + PlotCompetitiveSkillScore, PlotTimeSeries, PlotBootstrapSkillScore, PlotAvailability, PlotConditionalQuantiles, \ + PlotSeparationOfScales from mlair.run_modules.run_environment import RunEnvironment @@ -262,7 +263,10 @@ class PostProcessing(RunEnvironment): plot_list = self.data_store.get("plot_list", "postprocessing") time_dimension = self.data_store.get("time_dim") - if self.bootstrap_skill_scores is not None and "PlotBootstrapSkillScore" in plot_list: + if ("filter" in self.test_data[0].get_X(as_numpy=False)[0].coords) and ("PlotSeparationOfScales" in plot_list): + PlotSeparationOfScales(self.test_data, plot_folder=self.plot_path) + + if (self.bootstrap_skill_scores is not None) and ("PlotBootstrapSkillScore" in plot_list): PlotBootstrapSkillScore(self.bootstrap_skill_scores, plot_folder=self.plot_path, model_setup="CNN") if "PlotConditionalQuantiles" in plot_list: diff --git a/mlair/run_modules/pre_processing.py b/mlair/run_modules/pre_processing.py index 82af9cf02cda9401237bac15ccf0a52fa10acdad..4cee4a9744f33c86e8802aad27125cf0e0b30f3a 100644 --- a/mlair/run_modules/pre_processing.py +++ b/mlair/run_modules/pre_processing.py @@ -207,6 +207,7 @@ class PreProcessing(RunEnvironment): logging.info(f"check valid stations started{' (%s)' % (set_name if set_name is not None else 'all')}") # calculate transformation using train data if set_name == 
"train": + logging.info("setup transformation using train data exclusively") self.transformation(data_handler, set_stations) # start station check collection = DataCollection() diff --git a/run_mixed_sampling.py b/run_mixed_sampling.py index 56eef7f29872f9ab0ab995935a9008bfdfc6f930..5288063ac583e8dad24e253c5ae16810b540c5c8 100644 --- a/run_mixed_sampling.py +++ b/run_mixed_sampling.py @@ -4,17 +4,18 @@ __date__ = '2019-11-14' import argparse from mlair.workflows import DefaultWorkflow -from mlair.data_handler.data_handler_mixed_sampling import DataHandlerMixedSampling, DataHandlerMixedSamplingWithFilter +from mlair.data_handler.data_handler_mixed_sampling import DataHandlerMixedSampling, DataHandlerMixedSamplingWithFilter, \ + DataHandlerMixedSamplingSeparationOfScales def main(parser_args): args = dict(sampling="daily", sampling_inputs="hourly", - window_history_size=72, + window_history_size=24, **parser_args.__dict__, - data_handler=DataHandlerMixedSampling, # WithFilter, - kz_filter_length=[365 * 24, 20 * 24], - kz_filter_iter=[3, 5], + data_handler=DataHandlerMixedSamplingSeparationOfScales, + kz_filter_length=[100 * 24, 15 * 24], + kz_filter_iter=[4, 5], start="2006-01-01", train_start="2006-01-01", end="2011-12-31", diff --git a/test/test_configuration/test_defaults.py b/test/test_configuration/test_defaults.py index 7dc7199f2d8ed75af2d4f968a1f52ff3ee15baec..fffe7c84075eeeab37ebf59d52bc42dbf87bf522 100644 --- a/test/test_configuration/test_defaults.py +++ b/test/test_configuration/test_defaults.py @@ -70,4 +70,4 @@ class TestAllDefaults: assert DEFAULT_NUMBER_OF_BOOTSTRAPS == 20 assert DEFAULT_PLOT_LIST == ["PlotMonthlySummary", "PlotStationMap", "PlotClimatologicalSkillScore", "PlotTimeSeries", "PlotCompetitiveSkillScore", "PlotBootstrapSkillScore", - "PlotConditionalQuantiles", "PlotAvailability"] + "PlotConditionalQuantiles", "PlotAvailability", "PlotSeparationOfScales"] diff --git a/test/test_data_handler/test_data_handler.py 
b/test/test_data_handler/test_data_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..418c7946efe160c9bbfeccff9908a6cf17dec17f --- /dev/null +++ b/test/test_data_handler/test_data_handler.py @@ -0,0 +1,67 @@ +import pytest +import inspect + +from mlair.data_handler.abstract_data_handler import AbstractDataHandler + + +class TestDefaultDataHandler: + + def test_required_attributes(self): + dh = AbstractDataHandler + assert hasattr(dh, "_requirements") + assert hasattr(dh, "__init__") + assert hasattr(dh, "build") + assert hasattr(dh, "requirements") + assert hasattr(dh, "own_args") + assert hasattr(dh, "transformation") + assert hasattr(dh, "get_X") + assert hasattr(dh, "get_Y") + assert hasattr(dh, "get_data") + assert hasattr(dh, "get_coordinates") + + def test_init(self): + assert isinstance(AbstractDataHandler(), AbstractDataHandler) + + def test_build(self): + assert isinstance(AbstractDataHandler.build(), AbstractDataHandler) + + def test_requirements(self): + dh = AbstractDataHandler() + assert isinstance(dh._requirements, list) + assert len(dh._requirements) == 0 + assert isinstance(dh.requirements(), list) + assert len(dh.requirements()) == 0 + + def test_own_args(self): + dh = AbstractDataHandler() + assert isinstance(dh.own_args(), list) + assert len(dh.own_args()) == 0 + assert "self" not in dh.own_args() + + def test_transformation(self): + assert AbstractDataHandler.transformation() is None + + def test_get_X(self): + dh = AbstractDataHandler() + with pytest.raises(NotImplementedError): + dh.get_X() + assert sorted(["self", "upsampling", "as_numpy"]) == sorted(inspect.getfullargspec(dh.get_X).args) + assert (False, False) == inspect.getfullargspec(dh.get_X).defaults + + def test_get_Y(self): + dh = AbstractDataHandler() + with pytest.raises(NotImplementedError): + dh.get_Y() + assert sorted(["self", "upsampling", "as_numpy"]) == sorted(inspect.getfullargspec(dh.get_Y).args) + assert (False, False) == 
inspect.getfullargspec(dh.get_Y).defaults + + def test_get_data(self): + dh = AbstractDataHandler() + with pytest.raises(NotImplementedError): + dh.get_data() + assert sorted(["self", "upsampling", "as_numpy"]) == sorted(inspect.getfullargspec(dh.get_data).args) + assert (False, False) == inspect.getfullargspec(dh.get_data).defaults + + def test_get_coordinates(self): + dh = AbstractDataHandler() + assert dh.get_coordinates() is None