Commit 023ee744 authored by Lukas Leufen

Merge branch 'lukas_issue292_feat_data-preprocessing' into 'develop'

Resolve "implement lazy data preprocessing"

See merge request !274
parents 7aefc11f 7cc89dec
Part of 7 merge requests: !319 add all changes of dev into release v1.4.0 branch, !318 Resolve "release v1.4.0", !283 Merge latest develop into falcos issue, !279 include Develop, !278 Felix issue295 transformation parameters in data handler, !274 Resolve "implement lazy data preprocessing", !259 Draft: Resolve "WRF-Datahandler should inherit from SingleStationDatahandler"
Pipeline #63364 passed
......@@ -2,6 +2,7 @@ absl-py==0.11.0
appdirs==1.4.4
astor==0.8.1
attrs==20.3.0
bottleneck==1.3.2
cached-property==1.5.2
certifi==2020.12.5
cftime==1.4.1
......@@ -9,6 +10,7 @@ chardet==4.0.0
coverage==5.4
cycler==0.10.0
dask==2021.2.0
dill==0.3.3
fsspec==0.8.5
gast==0.4.0
grpcio==1.35.0
......
......@@ -2,6 +2,7 @@ absl-py==0.11.0
appdirs==1.4.4
astor==0.8.1
attrs==20.3.0
bottleneck==1.3.2
cached-property==1.5.2
certifi==2020.12.5
cftime==1.4.1
......@@ -9,6 +10,7 @@ chardet==4.0.0
coverage==5.4
cycler==0.10.0
dask==2021.2.0
dill==0.3.3
fsspec==0.8.5
gast==0.4.0
grpcio==1.35.0
......
......@@ -55,3 +55,6 @@ class AbstractDataHandler:
def get_coordinates(self) -> Union[None, Dict]:
"""Return coordinates as dictionary with keys `lon` and `lat`."""
return None
def _hash_list(self):
return []
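
Note: this default deliberately returns an empty list; concrete handlers declare a class-level _hash attribute naming the settings that identify a preprocessing configuration and override _hash_list accordingly (see DataHandlerSingleStation below). A minimal sketch of that contract, with made-up attribute names and an assumed import path:

from mlair.data_handler.abstract_data_handler import AbstractDataHandler  # path assumed

class ExampleHandler(AbstractDataHandler):
    _hash = ["station", "sampling"]  # hypothetical identifying settings

    def _hash_list(self):
        # same override as added in DataHandlerSingleStation further down
        return sorted(set(self._hash))
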
......@@ -8,6 +8,7 @@ import numpy as np
import pandas as pd
import xarray as xr
from typing import List, Union
from functools import partial
from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation
from mlair.data_handler import DefaultDataHandler
......@@ -22,6 +23,7 @@ class DataHandlerKzFilterSingleStation(DataHandlerSingleStation):
"""Data handler for a single station to be used by a superior data handler. Inputs are kz filtered."""
_requirements = remove_items(inspect.getfullargspec(DataHandlerSingleStation).args, ["self", "station"])
_hash = DataHandlerSingleStation._hash + ["kz_filter_length", "kz_filter_iter", "filter_dim"]
DEFAULT_FILTER_DIM = "filter"
......@@ -38,10 +40,7 @@ class DataHandlerKzFilterSingleStation(DataHandlerSingleStation):
def _check_sampling(self, **kwargs):
assert kwargs.get("sampling") == "hourly" # This data handler requires hourly data resolution
def setup_samples(self):
"""
Setup samples. This method prepares and creates samples X, and labels Y.
"""
def make_input_target(self):
data, self.meta = self.load_data(self.path, self.station, self.statistics_per_var, self.sampling,
self.station_type, self.network, self.store_data_locally, self.data_origin)
self._data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method,
......@@ -54,9 +53,6 @@ class DataHandlerKzFilterSingleStation(DataHandlerSingleStation):
# import matplotlib.pyplot as plt
# self.input_data.sel(filter="74d", variables="temp", Stations="DEBW107").plot()
# self.input_data.sel(variables="temp", Stations="DEBW107").plot.line(hue="filter")
if self.do_transformation is True:
self.call_transform()
self.make_samples()
@TimeTrackingWrapper
def apply_kz_filter(self):
......@@ -88,6 +84,15 @@ class DataHandlerKzFilterSingleStation(DataHandlerSingleStation):
return self.history.transpose(self.time_dim, self.window_dim, self.iter_dim, self.target_dim,
self.filter_dim).copy()
def _create_lazy_data(self):
return [self._data, self.meta, self.input_data, self.target_data, self.cutoff_period, self.cutoff_period_days]
def _extract_lazy(self, lazy_data):
_data, self.meta, _input_data, _target_data, self.cutoff_period, self.cutoff_period_days = lazy_data
f_prep = partial(self._slice_prep, start=self.start, end=self.end)
self._data, self.input_data, self.target_data = list(map(f_prep, [_data, _input_data, _target_data]))
class DataHandlerKzFilter(DefaultDataHandler):
"""Data handler using kz filtered data."""
......
......@@ -12,6 +12,7 @@ import inspect
from typing import Callable
import datetime as dt
from typing import Any
from functools import partial
import numpy as np
import pandas as pd
......@@ -54,15 +55,9 @@ class DataHandlerMixedSamplingSingleStation(DataHandlerSingleStation):
assert len(parameter) == 2 # (inputs, targets)
kwargs.update({parameter_name: parameter})
def setup_samples(self):
"""
Setup samples. This method prepares and creates samples X, and labels Y.
"""
def make_input_target(self):
self._data = list(map(self.load_and_interpolate, [0, 1])) # load input (0) and target (1) data
self.set_inputs_and_targets()
if self.do_transformation is True:
self.call_transform()
self.make_samples()
def load_and_interpolate(self, ind) -> [xr.DataArray, pd.DataFrame]:
vars = [self.variables, self.target_var]
......@@ -83,6 +78,12 @@ class DataHandlerMixedSamplingSingleStation(DataHandlerSingleStation):
assert len(sampling) == 2
return list(map(lambda x: super(__class__, self).setup_data_path(data_path, x), sampling))
def _extract_lazy(self, lazy_data):
_data, self.meta, _input_data, _target_data = lazy_data
f_prep = partial(self._slice_prep, start=self.start, end=self.end)
self._data = f_prep(_data[0]), f_prep(_data[1])
self.input_data, self.target_data = list(map(f_prep, [_input_data, _target_data]))
class DataHandlerMixedSampling(DefaultDataHandler):
"""Data handler using mixed sampling for input and target."""
......@@ -104,19 +105,14 @@ class DataHandlerMixedSamplingWithFilterSingleStation(DataHandlerMixedSamplingSi
def _check_sampling(self, **kwargs):
assert kwargs.get("sampling") == ("hourly", "daily")
def setup_samples(self):
def make_input_target(self):
"""
Setup samples. This method prepares and creates samples X, and labels Y.
A KZ filter is applied to the input data, which has hourly resolution. Labels Y are provided as aggregated values
with daily resolution.
"""
self._data = list(map(self.load_and_interpolate, [0, 1])) # load input (0) and target (1) data
self.set_inputs_and_targets()
self.apply_kz_filter()
if self.do_transformation is True:
self.call_transform()
self.make_samples()
def estimate_filter_width(self):
"""
......@@ -130,14 +126,24 @@ class DataHandlerMixedSamplingWithFilterSingleStation(DataHandlerMixedSamplingSi
new_date = dt.datetime.strptime(date, "%Y-%m-%d") + dt.timedelta(hours=delta)
return new_date.strftime("%Y-%m-%d")
def load_and_interpolate(self, ind) -> [xr.DataArray, pd.DataFrame]:
def update_start_end(self, ind):
if ind == 0: # for inputs
estimated_filter_width = self.estimate_filter_width()
start = self._add_time_delta(self.start, -estimated_filter_width)
end = self._add_time_delta(self.end, estimated_filter_width)
else: # target
start, end = self.start, self.end
return start, end
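
For illustration, the widened input period can be reproduced with the _add_time_delta helper above; the delta is interpreted in hours (dt.timedelta(hours=delta)), and the 74-day width below is only a placeholder for whatever estimate_filter_width() returns:

import datetime as dt

def add_time_delta(date, delta_hours):
    # standalone copy of _add_time_delta, for the example only
    new_date = dt.datetime.strptime(date, "%Y-%m-%d") + dt.timedelta(hours=delta_hours)
    return new_date.strftime("%Y-%m-%d")

start, end = "2011-01-01", "2011-12-31"
width = 74 * 24                          # hypothetical filter width in hours
print(add_time_delta(start, -width))     # 2010-10-19 -> inputs start earlier
print(add_time_delta(end, width))        # 2012-03-14 -> inputs end later
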
def load_and_interpolate(self, ind) -> [xr.DataArray, pd.DataFrame]:
start, end = self.update_start_end(ind)
# if ind == 0: # for inputs
# estimated_filter_width = self.estimate_filter_width()
# start = self._add_time_delta(self.start, -estimated_filter_width)
# end = self._add_time_delta(self.end, estimated_filter_width)
# else: # target
# start, end = self.start, self.end
vars = [self.variables, self.target_var]
stats_per_var = helpers.select_from_dict(self.statistics_per_var, vars[ind])
......@@ -149,6 +155,13 @@ class DataHandlerMixedSamplingWithFilterSingleStation(DataHandlerMixedSamplingSi
limit=self.interpolation_limit[ind])
return data
def _extract_lazy(self, lazy_data):
_data, self.meta, _input_data, _target_data, self.cutoff_period, self.cutoff_period_days = lazy_data
start_inp, end_inp = self.update_start_end(0)
self._data = list(map(self._slice_prep, _data, [start_inp, self.start], [end_inp, self.end]))
self.input_data = self._slice_prep(_input_data, start_inp, end_inp)
self.target_data = self._slice_prep(_target_data, self.start, self.end)
class DataHandlerMixedSamplingWithFilter(DefaultDataHandler):
"""Data handler using mixed sampling for input and target. Inputs are temporal filtered."""
......@@ -169,6 +182,7 @@ class DataHandlerSeparationOfScalesSingleStation(DataHandlerMixedSamplingWithFil
"""
_requirements = DataHandlerMixedSamplingWithFilterSingleStation.requirements()
_hash = DataHandlerMixedSamplingWithFilterSingleStation._hash + ["time_delta"]
def __init__(self, *args, time_delta=np.sqrt, **kwargs):
assert isinstance(time_delta, Callable)
......@@ -204,7 +218,7 @@ class DataHandlerSeparationOfScalesSingleStation(DataHandlerMixedSamplingWithFil
time_deltas = np.round(self.time_delta(self.cutoff_period)).astype(int)
start, end = window, 1
res = []
window_array = self.create_index_array(self.window_dim.range(start, end), squeeze_dim=self.target_dim)
window_array = self.create_index_array(self.window_dim, range(start, end), squeeze_dim=self.target_dim)
for delta, filter_name in zip(np.append(time_deltas, 1), data.coords["filter"]):
res_filter = []
data_filter = data.sel({"filter": filter_name})
......@@ -212,7 +226,7 @@ class DataHandlerSeparationOfScalesSingleStation(DataHandlerMixedSamplingWithFil
res_filter.append(data_filter.shift({dim: -w * delta}))
res_filter = xr.concat(res_filter, dim=window_array).chunk()
res.append(res_filter)
res = xr.concat(res, dim="filter")
res = xr.concat(res, dim="filter").compute()
return res
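
For reference, a toy, self-contained version of the shift-and-concat pattern used above to stack lagged copies along the window dimension (dimension names are placeholders, not the handler's configured dims):

import numpy as np
import xarray as xr

data = xr.DataArray(np.arange(10.0), dims="datetime")
window = xr.DataArray(list(range(-3, 1)), dims="window")   # lag index -3..0
history = xr.concat([data.shift(datetime=-w) for w in window.values], dim=window)
# history gains a "window" dimension: entry w holds the series shifted by w steps,
# analogous to data_filter.shift({dim: -w * delta}) with delta = 1
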
def estimate_filter_width(self):
......
......@@ -5,9 +5,11 @@ __date__ = '2020-07-20'
import copy
import datetime as dt
import dill
import hashlib
import logging
import os
from functools import reduce
from functools import reduce, partial
from typing import Union, List, Iterable, Tuple, Dict, Optional
import numpy as np
......@@ -45,6 +47,10 @@ class DataHandlerSingleStation(AbstractDataHandler):
DEFAULT_INTERPOLATION_LIMIT = 0
DEFAULT_INTERPOLATION_METHOD = "linear"
_hash = ["station", "statistics_per_var", "data_origin", "station_type", "network", "sampling", "target_dim",
"target_var", "time_dim", "iter_dim", "window_dim", "window_history_size", "window_history_offset",
"window_lead_time", "interpolation_limit", "interpolation_method"]
def __init__(self, station, data_path, statistics_per_var, station_type=DEFAULT_STATION_TYPE,
network=DEFAULT_NETWORK, sampling: Union[str, Tuple[str]] = DEFAULT_SAMPLING,
target_dim=DEFAULT_TARGET_DIM, target_var=DEFAULT_TARGET_VAR, time_dim=DEFAULT_TIME_DIM,
......@@ -54,10 +60,16 @@ class DataHandlerSingleStation(AbstractDataHandler):
interpolation_limit: Union[int, Tuple[int]] = DEFAULT_INTERPOLATION_LIMIT,
interpolation_method: Union[str, Tuple[str]] = DEFAULT_INTERPOLATION_METHOD,
overwrite_local_data: bool = False, transformation=None, store_data_locally: bool = True,
min_length: int = 0, start=None, end=None, variables=None, data_origin: Dict = None, **kwargs):
min_length: int = 0, start=None, end=None, variables=None, data_origin: Dict = None,
lazy_preprocessing: bool = False, **kwargs):
super().__init__()
self.station = helpers.to_list(station)
self.path = self.setup_data_path(data_path, sampling)
self.lazy = lazy_preprocessing
self.lazy_path = None
if self.lazy is True:
self.lazy_path = os.path.join(data_path, "lazy_data", self.__class__.__name__)
check_path_and_create(self.lazy_path)
self.statistics_per_var = statistics_per_var
self.data_origin = data_origin
self.do_transformation = transformation is not None
......@@ -215,15 +227,46 @@ class DataHandlerSingleStation(AbstractDataHandler):
"""
Setup samples. This method prepares and creates samples X, and labels Y.
"""
if self.lazy is False:
self.make_input_target()
else:
self.load_lazy()
self.store_lazy()
if self.do_transformation is True:
self.call_transform()
self.make_samples()
def store_lazy(self):
hash = self._get_hash()
filename = os.path.join(self.lazy_path, hash + ".pickle")
if not os.path.exists(filename):
dill.dump(self._create_lazy_data(), file=open(filename, "wb"))
def _create_lazy_data(self):
return [self._data, self.meta, self.input_data, self.target_data]
def load_lazy(self):
hash = self._get_hash()
filename = os.path.join(self.lazy_path, hash + ".pickle")
try:
with open(filename, "rb") as pickle_file:
lazy_data = dill.load(pickle_file)
self._extract_lazy(lazy_data)
except FileNotFoundError:
self.make_input_target()
def _extract_lazy(self, lazy_data):
_data, self.meta, _input_data, _target_data = lazy_data
f_prep = partial(self._slice_prep, start=self.start, end=self.end)
self._data, self.input_data, self.target_data = list(map(f_prep, [_data, _input_data, _target_data]))
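
The store_lazy/load_lazy pair above is essentially a hash-keyed pickle cache: compute once, dump with dill, and on later runs reload and only redo the cheap slicing and transformation. A minimal standalone sketch of that pattern (function and file names are illustrative, not MLAir API):

import dill

def load_or_compute(cache_file, compute):
    # try the cached pickle first, otherwise compute and store it
    try:
        with open(cache_file, "rb") as f:
            return dill.load(f)
    except FileNotFoundError:
        result = compute()
        with open(cache_file, "wb") as f:
            dill.dump(result, f)
        return result

payload = load_or_compute("cache.pickle", lambda: [1, 2, 3])  # placeholder payload
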
def make_input_target(self):
data, self.meta = self.load_data(self.path, self.station, self.statistics_per_var, self.sampling,
self.station_type, self.network, self.store_data_locally, self.data_origin,
self.start, self.end)
self._data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method,
limit=self.interpolation_limit)
self.set_inputs_and_targets()
if self.do_transformation is True:
self.call_transform()
self.make_samples()
def set_inputs_and_targets(self):
inputs = self._data.sel({self.target_dim: helpers.to_list(self.variables)})
......@@ -658,6 +701,13 @@ class DataHandlerSingleStation(AbstractDataHandler):
return self.transform(data, dim=dim, opts=self._transformation[pos], inverse=inverse,
transformation_dim=self.target_dim)
def _hash_list(self):
return sorted(list(set(self._hash)))
def _get_hash(self):
hash = "".join([str(self.__getattribute__(e)) for e in self._hash_list()]).encode()
return hashlib.md5(hash).hexdigest()
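
The resulting cache key is just the MD5 digest of the concatenated string representations of the attributes listed in _hash_list; a quick illustration with made-up attribute values:

import hashlib

attrs = {"station": ["DEBW107"], "sampling": "daily", "window_history_size": 13}
key = "".join(str(attrs[k]) for k in sorted(attrs)).encode()
print(hashlib.md5(key).hexdigest())  # becomes the file stem <lazy_path>/<hash>.pickle
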
if __name__ == "__main__":
# dp = AbstractDataPrep('data/', 'dummy', 'DEBW107', ['o3', 'temp'], statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'})
......
......@@ -8,6 +8,7 @@ import gc
import logging
import os
import pickle
import dill
import shutil
from functools import reduce
from typing import Tuple, Union, List
......@@ -86,7 +87,7 @@ class DefaultDataHandler(AbstractDataHandler):
data = {"X": self._X, "Y": self._Y, "X_extreme": self._X_extreme, "Y_extreme": self._Y_extreme}
data = self._force_dask_computation(data)
with open(self._save_file, "wb") as f:
pickle.dump(data, f)
dill.dump(data, f)
logging.debug(f"save pickle data to {self._save_file}")
self._reset_data()
......@@ -101,7 +102,7 @@ class DefaultDataHandler(AbstractDataHandler):
def _load(self):
try:
with open(self._save_file, "rb") as f:
data = pickle.load(f)
data = dill.load(f)
logging.debug(f"load pickle data from {self._save_file}")
self._X, self._Y = data["X"], data["Y"]
self._X_extreme, self._Y_extreme = data["X_extreme"], data["Y_extreme"]
......
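
The swap from pickle to dill here (and in the KerasIterator below) presumably allows stored batches to contain objects the standard pickle module rejects, e.g. lambdas or locally defined functions; a quick illustrative check of the difference:

import pickle
import dill

f = lambda x: x + 1                              # plain pickle cannot serialise this
try:
    pickle.dumps(f)
except (pickle.PicklingError, AttributeError) as err:
    print("pickle failed:", err)
print(len(dill.dumps(f)), "bytes via dill")      # dill handles it
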
......@@ -9,6 +9,7 @@ import math
import os
import shutil
import pickle
import dill
from typing import Tuple, List
......@@ -109,7 +110,7 @@ class KerasIterator(keras.utils.Sequence):
"""Load pickle data from disk."""
file = self._path % index
with open(file, "rb") as f:
data = pickle.load(f)
data = dill.load(f)
return data["X"], data["Y"]
@staticmethod
......@@ -167,7 +168,7 @@ class KerasIterator(keras.utils.Sequence):
data = {"X": X, "Y": Y}
file = self._path % index
with open(file, "wb") as f:
pickle.dump(data, f)
dill.dump(data, f)
def _get_number_of_mini_batches(self, number_of_samples: int) -> int:
"""Return number of mini batches as the floored ration of number of samples to batch size."""
......
......@@ -11,8 +11,10 @@ import pandas as pd
from typing import Union, Tuple, Dict, List
from matplotlib import pyplot as plt
import itertools
import gc
import warnings
from mlair.helpers import to_list
from mlair.helpers import to_list, TimeTracking, TimeTrackingWrapper
Data = Union[xr.DataArray, pd.DataFrame]
......@@ -438,7 +440,7 @@ class SkillScores:
"""Calculate CASE IV."""
AI, BI, CI, data, suffix = self.skill_score_pre_calculations(internal_data, observation_name, forecast_name)
monthly_mean_external = self.create_monthly_mean_from_daily_data(external_data, index=data.index)
data = xr.concat([data, monthly_mean_external], dim="type")
data = xr.concat([data, monthly_mean_external], dim="type").dropna(dim="index")
mean, sigma = suffix["mean"], suffix["sigma"]
mean_external = monthly_mean_external.mean()
sigma_external = np.sqrt(monthly_mean_external.var())
......@@ -608,6 +610,48 @@ class KolmogorovZurbenkoFilterMovingWindow(KolmogorovZurbenkoBaseClass):
else:
return None
@TimeTrackingWrapper
def kz_filter_new(self, df, wl, itr):
"""
Low-pass filter the time series, i.e. pass only its low-frequency component.
If the filter method is mean, max, or min, this method calls construct and rechunks before the actual
calculation to improve performance. If the filter method is median or percentile, this approach is not
applicable and, depending on the data and window size, this method can become slow.
Args:
wl (int): window length
itr (int): number of iterations
"""
warnings.filterwarnings("ignore")
df_itr = df.__deepcopy__()
try:
kwargs = {"min_periods": int(0.7 * wl),
"center": True,
self.filter_dim: wl}
for i in np.arange(0, itr):
print(i)
rolling = df_itr.chunk().rolling(**kwargs)
if self.method not in ["percentile", "median"]:
rolling = rolling.construct("construct").chunk("auto")
if self.method == "median":
df_mv_avg_tmp = rolling.median()
elif self.method == "percentile":
df_mv_avg_tmp = rolling.quantile(self.percentile)
elif self.method == "max":
df_mv_avg_tmp = rolling.max("construct")
elif self.method == "min":
df_mv_avg_tmp = rolling.min("construct")
else:
df_mv_avg_tmp = rolling.mean("construct")
df_itr = df_mv_avg_tmp.compute()
del df_mv_avg_tmp, rolling
gc.collect()
return df_itr
except ValueError:
raise ValueError
@TimeTrackingWrapper
def kz_filter(self, df, wl, itr):
"""
Low-pass filter the time series, i.e. pass only its low-frequency component.
......@@ -616,15 +660,18 @@ class KolmogorovZurbenkoFilterMovingWindow(KolmogorovZurbenkoBaseClass):
wl (int): window length
itr (int): number of iterations
"""
import warnings
warnings.filterwarnings("ignore")
df_itr = df.__deepcopy__()
try:
kwargs = {"min_periods": 1,
kwargs = {"min_periods": int(0.7 * wl),
"center": True,
self.filter_dim: wl}
iter_vars = df_itr.coords["variables"].values
for var in iter_vars:
df_itr_var = df_itr.sel(variables=[var]).chunk()
df_itr_var = df_itr.sel(variables=[var])
for _ in np.arange(0, itr):
df_itr_var = df_itr_var.chunk()
rolling = df_itr_var.rolling(**kwargs)
if self.method == "median":
df_mv_avg_tmp = rolling.median()
......@@ -637,7 +684,7 @@ class KolmogorovZurbenkoFilterMovingWindow(KolmogorovZurbenkoBaseClass):
else:
df_mv_avg_tmp = rolling.mean()
df_itr_var = df_mv_avg_tmp.compute()
df_itr = df_itr.drop_sel(variables=var).combine_first(df_itr_var)
df_itr.loc[{"variables": [var]}] = df_itr_var
return df_itr
except ValueError:
raise ValueError
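
Both kz_filter variants implement the Kolmogorov-Zurbenko filter, in its classic form a centred moving mean of window length wl applied itr times (the class additionally supports median, percentile, max and min via self.method). A standalone pandas illustration on synthetic data, with arbitrary parameters:

import numpy as np
import pandas as pd

def kz(series: pd.Series, wl: int, itr: int) -> pd.Series:
    # iterated centred moving average = KZ low-pass filter
    for _ in range(itr):
        series = series.rolling(window=wl, min_periods=int(0.7 * wl), center=True).mean()
    return series

t = np.arange(1000)
noisy = np.sin(2 * np.pi * t / 365) + np.random.normal(0, 0.5, t.size)
smooth = kz(pd.Series(noisy), wl=29, itr=3)  # keeps the slow annual cycle, damps the noise
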
......@@ -2,6 +2,7 @@ absl-py==0.11.0
appdirs==1.4.4
astor==0.8.1
attrs==20.3.0
bottleneck==1.3.2
cached-property==1.5.2
certifi==2020.12.5
cftime==1.4.1
......@@ -9,6 +10,7 @@ chardet==4.0.0
coverage==5.4
cycler==0.10.0
dask==2021.2.0
dill==0.3.3
fsspec==0.8.5
gast==0.4.0
grpcio==1.35.0
......
......@@ -2,6 +2,7 @@ absl-py==0.11.0
appdirs==1.4.4
astor==0.8.1
attrs==20.3.0
bottleneck==1.3.2
cached-property==1.5.2
certifi==2020.12.5
cftime==1.4.1
......@@ -9,6 +10,7 @@ chardet==4.0.0
coverage==5.4
cycler==0.10.0
dask==2021.2.0
dill==0.3.3
fsspec==0.8.5
gast==0.4.0
grpcio==1.35.0
......
......@@ -37,7 +37,7 @@ class TestDataHandlerMixedSamplingSingleStation:
req = object.__new__(DataHandlerSingleStation)
assert sorted(obj._requirements) == sorted(remove_items(req.requirements(), "station"))
@mock.patch("mlair.data_handler.data_handler_mixed_sampling.DataHandlerMixedSamplingSingleStation.setup_samples")
@mock.patch("mlair.data_handler.data_handler_single_station.DataHandlerSingleStation.setup_samples")
def test_init(self, mock_super_init):
obj = DataHandlerMixedSamplingSingleStation("first_arg", "second", {}, test=23, sampling="hourly",
interpolation_limit=(1, 10))
......