Skip to content
Snippets Groups Projects
Commit 3a3cc762 authored by leufen1's avatar leufen1
Browse files

can create a hash from all important parameters, lazy loading works for all data handlers

parent b10bca2b
No related branches found
No related tags found
8 merge requests!319add all changes of dev into release v1.4.0 branch,!318Resolve "release v1.4.0",!283Merge latest develop into falcos issue,!279include Develop,!278Felix issue295 transformation parameters in data handler,!275include lazy preprocessing,!274Resolve "implement lazy data preprocessing",!259Draft: Resolve "WRF-Datahandler should inherit from SingleStationDatahandler"
Pipeline #63261 passed
......@@ -9,6 +9,7 @@ chardet==4.0.0
coverage==5.4
cycler==0.10.0
dask==2021.2.0
dill==0.3.3
fsspec==0.8.5
gast==0.4.0
grpcio==1.35.0
......
......@@ -9,6 +9,7 @@ chardet==4.0.0
coverage==5.4
cycler==0.10.0
dask==2021.2.0
dill==0.3.3
fsspec==0.8.5
gast==0.4.0
grpcio==1.35.0
......
......@@ -55,3 +55,6 @@ class AbstractDataHandler:
def get_coordinates(self) -> Union[None, Dict]:
    """Coordinates of this data handler; the abstract base provides none.

    :return: dictionary with keys `lon` and `lat`, or None if not available
    """
    return None
def _hash_list(self):
return []
......@@ -22,6 +22,7 @@ class DataHandlerKzFilterSingleStation(DataHandlerSingleStation):
"""Data handler for a single station to be used by a superior data handler. Inputs are kz filtered."""
_requirements = remove_items(inspect.getfullargspec(DataHandlerSingleStation).args, ["self", "station"])
_hash = DataHandlerSingleStation._hash + ["kz_filter_length", "kz_filter_iter", "filter_dim"]
DEFAULT_FILTER_DIM = "filter"
......
......@@ -12,6 +12,10 @@ import inspect
from typing import Callable
import datetime as dt
from typing import Any
import os
import dill
import logging
from functools import partial
import numpy as np
import pandas as pd
......@@ -77,6 +81,12 @@ class DataHandlerMixedSamplingSingleStation(DataHandlerSingleStation):
assert len(sampling) == 2
return list(map(lambda x: super(__class__, self).setup_data_path(data_path, x), sampling))
def _extract_lazy(self, lazy_data):
_data, self.meta, _input_data, _target_data = lazy_data
f_prep = partial(self._slice_prep, start=self.start, end=self.end)
self._data = f_prep(_data[0]), f_prep(_data[1])
self.input_data, self.target_data = list(map(f_prep, [_input_data, _target_data]))
class DataHandlerMixedSampling(DefaultDataHandler):
"""Data handler using mixed sampling for input and target."""
......@@ -119,14 +129,24 @@ class DataHandlerMixedSamplingWithFilterSingleStation(DataHandlerMixedSamplingSi
new_date = dt.datetime.strptime(date, "%Y-%m-%d") + dt.timedelta(hours=delta)
return new_date.strftime("%Y-%m-%d")
def load_and_interpolate(self, ind) -> [xr.DataArray, pd.DataFrame]:
def update_start_end(self, ind):
    """Return the (start, end) date range for branch `ind`.

    For inputs (ind == 0) the configured range is widened on both sides by the
    estimated filter width; for targets it is returned unchanged.
    """
    if ind != 0:  # target branch: no widening required
        return self.start, self.end
    width = self.estimate_filter_width()
    widened_start = self._add_time_delta(self.start, -width)
    widened_end = self._add_time_delta(self.end, width)
    return widened_start, widened_end
def load_and_interpolate(self, ind) -> [xr.DataArray, pd.DataFrame]:
start, end = self.update_start_end(ind)
# if ind == 0: # for inputs
# estimated_filter_width = self.estimate_filter_width()
# start = self._add_time_delta(self.start, -estimated_filter_width)
# end = self._add_time_delta(self.end, estimated_filter_width)
# else: # target
# start, end = self.start, self.end
vars = [self.variables, self.target_var]
stats_per_var = helpers.select_from_dict(self.statistics_per_var, vars[ind])
......@@ -138,6 +158,16 @@ class DataHandlerMixedSamplingWithFilterSingleStation(DataHandlerMixedSamplingSi
limit=self.interpolation_limit[ind])
return data
def _create_lazy_data(self):
return [self._data, self.meta, self.input_data, self.target_data, self.cutoff_period, self.cutoff_period_days]
def _extract_lazy(self, lazy_data):
_data, self.meta, _input_data, _target_data, self.cutoff_period, self.cutoff_period_days = lazy_data
start_inp, end_inp = self.update_start_end(0)
self._data = list(map(self._slice_prep, _data, [start_inp, self.start], [end_inp, self.end]))
self.input_data = self._slice_prep(_input_data, start_inp, end_inp)
self.target_data = self._slice_prep(_target_data, self.start, self.end)
class DataHandlerMixedSamplingWithFilter(DefaultDataHandler):
"""Data handler using mixed sampling for input and target. Inputs are temporal filtered."""
......@@ -158,6 +188,7 @@ class DataHandlerSeparationOfScalesSingleStation(DataHandlerMixedSamplingWithFil
"""
_requirements = DataHandlerMixedSamplingWithFilterSingleStation.requirements()
_hash = DataHandlerMixedSamplingWithFilterSingleStation._hash + ["time_delta"]
def __init__(self, *args, time_delta=np.sqrt, **kwargs):
assert isinstance(time_delta, Callable)
......@@ -193,7 +224,7 @@ class DataHandlerSeparationOfScalesSingleStation(DataHandlerMixedSamplingWithFil
time_deltas = np.round(self.time_delta(self.cutoff_period)).astype(int)
start, end = window, 1
res = []
window_array = self.create_index_array(self.window_dim.range(start, end), squeeze_dim=self.target_dim)
window_array = self.create_index_array(self.window_dim, range(start, end), squeeze_dim=self.target_dim)
for delta, filter_name in zip(np.append(time_deltas, 1), data.coords["filter"]):
res_filter = []
data_filter = data.sel({"filter": filter_name})
......
......@@ -5,10 +5,11 @@ __date__ = '2020-07-20'
import copy
import datetime as dt
import dill
import hashlib
import logging
import os
from functools import reduce
from functools import reduce, partial
from typing import Union, List, Iterable, Tuple, Dict, Optional
import numpy as np
......@@ -46,6 +47,10 @@ class DataHandlerSingleStation(AbstractDataHandler):
DEFAULT_INTERPOLATION_LIMIT = 0
DEFAULT_INTERPOLATION_METHOD = "linear"
_hash = ["station", "statistics_per_var", "data_origin", "station_type", "network", "sampling", "target_dim",
"target_var", "time_dim", "iter_dim", "window_dim", "window_history_size", "window_history_offset",
"window_lead_time", "interpolation_limit", "interpolation_method"]
def __init__(self, station, data_path, statistics_per_var, station_type=DEFAULT_STATION_TYPE,
network=DEFAULT_NETWORK, sampling: Union[str, Tuple[str]] = DEFAULT_SAMPLING,
target_dim=DEFAULT_TARGET_DIM, target_var=DEFAULT_TARGET_VAR, time_dim=DEFAULT_TIME_DIM,
......@@ -101,7 +106,6 @@ class DataHandlerSingleStation(AbstractDataHandler):
self.observation = None
# create samples
# self.hash()
self.setup_samples()
def __str__(self):
......@@ -223,11 +227,41 @@ class DataHandlerSingleStation(AbstractDataHandler):
"""
Setup samples. This method prepares and creates samples X, and labels Y.
"""
if self.lazy is False:
self.make_input_target()
else:
self.load_lazy()
self.store_lazy()
if self.do_transformation is True:
self.call_transform()
self.make_samples()
def store_lazy(self):
    """Write this handler's prepared data to the lazy cache.

    The file name is the hash over all hash-relevant parameters (see
    `_get_hash`); writing is skipped if a file for this hash already exists.
    """
    # renamed local from `hash` to avoid shadowing the builtin
    cache_id = self._get_hash()
    filename = os.path.join(self.lazy_path, cache_id + ".pickle")
    if not os.path.exists(filename):
        # use a context manager so the file handle is closed even if dump fails
        # (the original opened the file inline and never closed it)
        with open(filename, "wb") as pickle_file:
            dill.dump(self._create_lazy_data(), file=pickle_file)
def _create_lazy_data(self):
return [self._data, self.meta, self.input_data, self.target_data]
def load_lazy(self):
    """Restore prepared data from the lazy cache, falling back to full preprocessing.

    The cache file is addressed by the hash over all hash-relevant parameters,
    so any parameter change leads to a cache miss and a fresh
    `make_input_target` run.
    """
    # renamed local from `hash` to avoid shadowing the builtin
    cache_id = self._get_hash()
    filename = os.path.join(self.lazy_path, cache_id + ".pickle")
    try:
        with open(filename, "rb") as pickle_file:
            lazy_data = dill.load(pickle_file)
            self._extract_lazy(lazy_data)
            logging.info("<<<loaded lazy file")
    except FileNotFoundError:
        logging.info(">>>could not load lazy file")
        self.make_input_target()
def _extract_lazy(self, lazy_data):
_data, self.meta, _input_data, _target_data = lazy_data
f_prep = partial(self._slice_prep, start=self.start, end=self.end)
self._data, self.input_data, self.target_data = list(map(f_prep, [_data, _input_data, _target_data]))
def make_input_target(self):
data, self.meta = self.load_data(self.path, self.station, self.statistics_per_var, self.sampling,
self.station_type, self.network, self.store_data_locally, self.data_origin,
......@@ -669,16 +703,12 @@ class DataHandlerSingleStation(AbstractDataHandler):
return self.transform(data, dim=dim, opts=self._transformation[pos], inverse=inverse,
transformation_dim=self.target_dim)
def _hash_list(self):
return sorted(list(set(self._hash)))
def _get_hash(self):
hash_list = [self.station, self.statistics_per_var, self.data_origin, self.station_type, self.network,
self.sampling, self.target_dim, self.target_var, self.time_dim, self.iter_dim, self.window_dim,
self.window_history_size, self.window_history_offset, self.window_lead_time,
self.interpolation_limit, self.interpolation_method, self.min_length, self.start, self.end]
hash = "".join([str(e) for e in hash_list]).encode("utf-8")
m = hashlib.sha256()
m.update(hash)
return m.hexdigest()
hash = "".join([str(self.__getattribute__(e)) for e in self._hash_list()]).encode()
return hashlib.md5(hash).hexdigest()
if __name__ == "__main__":
......
......@@ -9,6 +9,7 @@ chardet==4.0.0
coverage==5.4
cycler==0.10.0
dask==2021.2.0
dill==0.3.3
fsspec==0.8.5
gast==0.4.0
grpcio==1.35.0
......
......@@ -9,6 +9,7 @@ chardet==4.0.0
coverage==5.4
cycler==0.10.0
dask==2021.2.0
dill==0.3.3
fsspec==0.8.5
gast==0.4.0
grpcio==1.35.0
......
......@@ -37,7 +37,7 @@ class TestDataHandlerMixedSamplingSingleStation:
req = object.__new__(DataHandlerSingleStation)
assert sorted(obj._requirements) == sorted(remove_items(req.requirements(), "station"))
@mock.patch("mlair.data_handler.data_handler_mixed_sampling.DataHandlerMixedSamplingSingleStation.setup_samples")
@mock.patch("mlair.data_handler.data_handler_single_station.DataHandlerSingleStation.setup_samples")
def test_init(self, mock_super_init):
obj = DataHandlerMixedSamplingSingleStation("first_arg", "second", {}, test=23, sampling="hourly",
interpolation_limit=(1, 10))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment.