from mlair.configuration.defaults import DEFAULT_SAMPLING
import logging
import os
import inspect
from typing import Callable

import numpy as np
import pandas as pd
import xarray as xr


# NOTE(review): the part of the module between the import block and the tail of
# DataHandlerMixedSamplingWithFilter was not visible and is intentionally not reproduced here.


class DataHandlerMixedSamplingWithFilter(DefaultDataHandler):
    # NOTE(review): only the last three statements of this class were visible; anything that precedes them
    # (e.g. a docstring) must be taken from the full module.
    data_handler = DataHandlerMixedSamplingWithFilterSingleStation
    data_handler_transformation = DataHandlerMixedSamplingWithFilterSingleStation
    _requirements = data_handler.requirements()


class DataHandlerMixedSamplingSeparationOfScalesSingleStation(DataHandlerMixedSamplingWithFilterSingleStation):
    """
    Single station data handler that builds the history window with a filter dependent time step.

    For each entry along the "filter" coordinate of the input data, the look-back window is created with its own
    step size. The step sizes are derived by applying `time_delta` to `self.cutoff_period`.
    """

    _requirements = DataHandlerMixedSamplingWithFilterSingleStation.requirements()

    def __init__(self, *args, time_delta: Callable = np.sqrt, **kwargs):
        """
        Store the step size function and delegate everything else to the parent class.

        :param time_delta: callable that is applied on `self.cutoff_period` to calculate the step size (in time steps)
            used per filter component, default is `np.sqrt`
        :raises TypeError: if `time_delta` is not callable
        """
        # Explicit check instead of `assert`, because assertions are removed when Python runs with -O.
        if not callable(time_delta):
            raise TypeError(f"time_delta must be callable, but got {type(time_delta).__name__}.")
        # Set before calling super().__init__ as in the original implementation.
        # NOTE(review): presumably the parent constructor already triggers make_history_window - confirm.
        self.time_delta = time_delta
        super().__init__(*args, **kwargs)

    def make_history_window(self, dim_name_of_inputs: str, window: int, dim_name_of_shift: str) -> None:
        """
        Create a xr.DataArray containing history data and store it in the history attribute.

        The input data is shifted window+1 times (see `stride`) and stacked along a new dimension 'window'. This is
        used to represent history in the data. Nothing is returned, results are stored in `self.history`.

        :param dim_name_of_inputs: Name of dimension which contains the input variables (not used by this
            implementation)
        :param window: number of time steps to look back in history
                Note: window will be treated as negative value. This should be in agreement with looking back on
                a time line. Nonetheless positive values are allowed but they are converted to its negative
                expression
        :param dim_name_of_shift: Dimension along shift will be applied
        """
        window = -abs(window)
        data = self.input_data.data
        self.history = self.stride(data, dim_name_of_shift, window)

    def stride(self, data: xr.DataArray, dim: str, window: int) -> xr.DataArray:
        """
        Build the history window separately for each filter component using a component specific step size.

        Step sizes are calculated as rounded integer result of `self.time_delta(self.cutoff_period)`. A step size of
        1 is appended for one additional filter component. Components are paired with their step size in the order
        of `data.coords["filter"]`; if there are more components than step sizes, surplus components are skipped
        (behaviour of `zip`).

        :param data: data to create the history from, must have a "filter" coordinate
        :param dim: name of the dimension along which data is shifted
        :param window: first (negative) window index, the window covers the indices window, ..., 0
        :return: chunked data array with the additional dimension 'window', concatenated along "filter"
        """
        # this is just a code snippet to check the results of the kz filter
        # import matplotlib
        # matplotlib.use("TkAgg")
        # import matplotlib.pyplot as plt
        # xr.concat(res, dim="filter").sel({"variables":"temp", "Stations":"DEBW107", "datetime":"2010-01-01T00:00:00"}).plot.line(hue="filter")

        time_deltas = np.round(self.time_delta(self.cutoff_period)).astype(int)
        offsets = range(window, 1)
        window_array = self.create_index_array('window', offsets, squeeze_dim=self.target_dim)
        res = []
        for delta, filter_name in zip(np.append(time_deltas, 1), data.coords["filter"]):
            data_filter = data.sel({"filter": filter_name})
            # w <= 0, so -w * delta is a non-negative shift: window index w holds the value from |w| * delta steps
            # earlier. Cast to a plain int to hand a builtin integer to shift instead of a numpy scalar.
            shifted = [data_filter.shift({dim: int(-w * delta)}) for w in offsets]
            res.append(xr.concat(shifted, dim=window_array))
        return xr.concat(res, dim="filter").chunk()


class DataHandlerMixedSamplingSeparationOfScales(DefaultDataHandler):
    """Data handler using mixed sampling for input and target. Inputs are temporal filtered and different time step
    sizes are applied in relation to frequencies."""

    # Use the separation of scales single station handler. Pointing to the plain filter handler (as before) would
    # silently skip the frequency dependent time steps described above.
    data_handler = DataHandlerMixedSamplingSeparationOfScalesSingleStation
    data_handler_transformation = DataHandlerMixedSamplingSeparationOfScalesSingleStation
    _requirements = data_handler.requirements()