diff --git a/HPC_setup/requirements_HDFML_additionals.txt b/HPC_setup/requirements_HDFML_additionals.txt index 12e09ccdd620c0c81c78ae6d4781d4feb5b94baf..7d6163a6d676cd54588ccd2ab8fe85e1375e31c6 100644 --- a/HPC_setup/requirements_HDFML_additionals.txt +++ b/HPC_setup/requirements_HDFML_additionals.txt @@ -9,6 +9,7 @@ chardet==4.0.0 coverage==5.4 cycler==0.10.0 dask==2021.2.0 +dill==0.3.3 fsspec==0.8.5 gast==0.4.0 grpcio==1.35.0 diff --git a/HPC_setup/requirements_JUWELS_additionals.txt b/HPC_setup/requirements_JUWELS_additionals.txt index 12e09ccdd620c0c81c78ae6d4781d4feb5b94baf..7d6163a6d676cd54588ccd2ab8fe85e1375e31c6 100644 --- a/HPC_setup/requirements_JUWELS_additionals.txt +++ b/HPC_setup/requirements_JUWELS_additionals.txt @@ -9,6 +9,7 @@ chardet==4.0.0 coverage==5.4 cycler==0.10.0 dask==2021.2.0 +dill==0.3.3 fsspec==0.8.5 gast==0.4.0 grpcio==1.35.0 diff --git a/mlair/data_handler/abstract_data_handler.py b/mlair/data_handler/abstract_data_handler.py index f085d18bb8d33839a0e3b5f6f3d5ada92134e7f6..419db059a58beeb4ed7e3e198e41b565f8dc7d25 100644 --- a/mlair/data_handler/abstract_data_handler.py +++ b/mlair/data_handler/abstract_data_handler.py @@ -55,3 +55,6 @@ class AbstractDataHandler: def get_coordinates(self) -> Union[None, Dict]: """Return coordinates as dictionary with keys `lon` and `lat`.""" return None + + def _hash_list(self): + return [] diff --git a/mlair/data_handler/data_handler_kz_filter.py b/mlair/data_handler/data_handler_kz_filter.py index face8f3c400b702209c03fefd7818481a0fb2038..1ff1a36f61670383e1f3bf9314045e799be7394d 100644 --- a/mlair/data_handler/data_handler_kz_filter.py +++ b/mlair/data_handler/data_handler_kz_filter.py @@ -22,6 +22,7 @@ class DataHandlerKzFilterSingleStation(DataHandlerSingleStation): """Data handler for a single station to be used by a superior data handler. Inputs are kz filtered.""" _requirements = remove_items(inspect.getfullargspec(DataHandlerSingleStation).args, ["self", "station"]) + _hash = DataHandlerSingleStation._hash + ["kz_filter_length", "kz_filter_iter", "filter_dim"] DEFAULT_FILTER_DIM = "filter" diff --git a/mlair/data_handler/data_handler_mixed_sampling.py b/mlair/data_handler/data_handler_mixed_sampling.py index ebcfbb4286f40ab2f8be2e1f8e46c7fa5ee45b14..acb62df9ce1fd9b61889c0b97973c953e9bda1ff 100644 --- a/mlair/data_handler/data_handler_mixed_sampling.py +++ b/mlair/data_handler/data_handler_mixed_sampling.py @@ -12,6 +12,10 @@ import inspect from typing import Callable import datetime as dt from typing import Any +import os +import dill +import logging +from functools import partial import numpy as np import pandas as pd @@ -77,6 +81,12 @@ class DataHandlerMixedSamplingSingleStation(DataHandlerSingleStation): assert len(sampling) == 2 return list(map(lambda x: super(__class__, self).setup_data_path(data_path, x), sampling)) + def _extract_lazy(self, lazy_data): + _data, self.meta, _input_data, _target_data = lazy_data + f_prep = partial(self._slice_prep, start=self.start, end=self.end) + self._data = f_prep(_data[0]), f_prep(_data[1]) + self.input_data, self.target_data = list(map(f_prep, [_input_data, _target_data])) + class DataHandlerMixedSampling(DefaultDataHandler): """Data handler using mixed sampling for input and target.""" @@ -119,14 +129,24 @@ class DataHandlerMixedSamplingWithFilterSingleStation(DataHandlerMixedSamplingSi new_date = dt.datetime.strptime(date, "%Y-%m-%d") + dt.timedelta(hours=delta) return new_date.strftime("%Y-%m-%d") - def load_and_interpolate(self, ind) -> [xr.DataArray, pd.DataFrame]: - + def update_start_end(self, ind): if ind == 0: # for inputs estimated_filter_width = self.estimate_filter_width() start = self._add_time_delta(self.start, -estimated_filter_width) end = self._add_time_delta(self.end, estimated_filter_width) else: # target start, end = self.start, self.end + return start, end + + def load_and_interpolate(self, ind) -> [xr.DataArray, pd.DataFrame]: + + start, end = self.update_start_end(ind) + # if ind == 0: # for inputs + # estimated_filter_width = self.estimate_filter_width() + # start = self._add_time_delta(self.start, -estimated_filter_width) + # end = self._add_time_delta(self.end, estimated_filter_width) + # else: # target + # start, end = self.start, self.end vars = [self.variables, self.target_var] stats_per_var = helpers.select_from_dict(self.statistics_per_var, vars[ind]) @@ -138,6 +158,16 @@ class DataHandlerMixedSamplingWithFilterSingleStation(DataHandlerMixedSamplingSi limit=self.interpolation_limit[ind]) return data + def _create_lazy_data(self): + return [self._data, self.meta, self.input_data, self.target_data, self.cutoff_period, self.cutoff_period_days] + + def _extract_lazy(self, lazy_data): + _data, self.meta, _input_data, _target_data, self.cutoff_period, self.cutoff_period_days = lazy_data + start_inp, end_inp = self.update_start_end(0) + self._data = list(map(self._slice_prep, _data, [start_inp, self.start], [end_inp, self.end])) + self.input_data = self._slice_prep(_input_data, start_inp, end_inp) + self.target_data = self._slice_prep(_target_data, self.start, self.end) + class DataHandlerMixedSamplingWithFilter(DefaultDataHandler): """Data handler using mixed sampling for input and target. Inputs are temporal filtered.""" @@ -158,6 +188,7 @@ class DataHandlerSeparationOfScalesSingleStation(DataHandlerMixedSamplingWithFil """ _requirements = DataHandlerMixedSamplingWithFilterSingleStation.requirements() + _hash = DataHandlerMixedSamplingWithFilterSingleStation._hash + ["time_delta"] def __init__(self, *args, time_delta=np.sqrt, **kwargs): assert isinstance(time_delta, Callable) @@ -193,7 +224,7 @@ class DataHandlerSeparationOfScalesSingleStation(DataHandlerMixedSamplingWithFil time_deltas = np.round(self.time_delta(self.cutoff_period)).astype(int) start, end = window, 1 res = [] - window_array = self.create_index_array(self.window_dim.range(start, end), squeeze_dim=self.target_dim) + window_array = self.create_index_array(self.window_dim, range(start, end), squeeze_dim=self.target_dim) for delta, filter_name in zip(np.append(time_deltas, 1), data.coords["filter"]): res_filter = [] data_filter = data.sel({"filter": filter_name}) diff --git a/mlair/data_handler/data_handler_single_station.py b/mlair/data_handler/data_handler_single_station.py index 820e601f25e1caa9b1860ed8f2f12efb1f0aa299..a8c6ea2e021251b72a9699426dc179d82fcb4e6d 100644 --- a/mlair/data_handler/data_handler_single_station.py +++ b/mlair/data_handler/data_handler_single_station.py @@ -5,10 +5,11 @@ __date__ = '2020-07-20' import copy import datetime as dt +import dill import hashlib import logging import os -from functools import reduce +from functools import reduce, partial from typing import Union, List, Iterable, Tuple, Dict, Optional import numpy as np @@ -46,6 +47,10 @@ class DataHandlerSingleStation(AbstractDataHandler): DEFAULT_INTERPOLATION_LIMIT = 0 DEFAULT_INTERPOLATION_METHOD = "linear" + _hash = ["station", "statistics_per_var", "data_origin", "station_type", "network", "sampling", "target_dim", + "target_var", "time_dim", "iter_dim", "window_dim", "window_history_size", "window_history_offset", + "window_lead_time", "interpolation_limit", "interpolation_method"] + def __init__(self, station, data_path, statistics_per_var, station_type=DEFAULT_STATION_TYPE, network=DEFAULT_NETWORK, sampling: Union[str, Tuple[str]] = DEFAULT_SAMPLING, target_dim=DEFAULT_TARGET_DIM, target_var=DEFAULT_TARGET_VAR, time_dim=DEFAULT_TIME_DIM, @@ -101,7 +106,6 @@ class DataHandlerSingleStation(AbstractDataHandler): self.observation = None # create samples - # self.hash() self.setup_samples() def __str__(self): @@ -223,11 +227,41 @@ class DataHandlerSingleStation(AbstractDataHandler): """ Setup samples. This method prepares and creates samples X, and labels Y. """ - self.make_input_target() + if self.lazy is False: + self.make_input_target() + else: + self.load_lazy() + self.store_lazy() if self.do_transformation is True: self.call_transform() self.make_samples() + def store_lazy(self): + hash = self._get_hash() + filename = os.path.join(self.lazy_path, hash + ".pickle") + if not os.path.exists(filename): + dill.dump(self._create_lazy_data(), file=open(filename, "wb")) + + def _create_lazy_data(self): + return [self._data, self.meta, self.input_data, self.target_data] + + def load_lazy(self): + hash = self._get_hash() + filename = os.path.join(self.lazy_path, hash + ".pickle") + try: + with open(filename, "rb") as pickle_file: + lazy_data = dill.load(pickle_file) + self._extract_lazy(lazy_data) + logging.info("<<<loaded lazy file") + except FileNotFoundError: + logging.info(">>>could not load lazy file") + self.make_input_target() + + def _extract_lazy(self, lazy_data): + _data, self.meta, _input_data, _target_data = lazy_data + f_prep = partial(self._slice_prep, start=self.start, end=self.end) + self._data, self.input_data, self.target_data = list(map(f_prep, [_data, _input_data, _target_data])) + def make_input_target(self): data, self.meta = self.load_data(self.path, self.station, self.statistics_per_var, self.sampling, self.station_type, self.network, self.store_data_locally, self.data_origin, @@ -669,16 +703,12 @@ class DataHandlerSingleStation(AbstractDataHandler): return self.transform(data, dim=dim, opts=self._transformation[pos], inverse=inverse, transformation_dim=self.target_dim) + def _hash_list(self): + return sorted(list(set(self._hash))) + def _get_hash(self): - hash_list = [self.station, self.statistics_per_var, self.data_origin, self.station_type, self.network, - self.sampling, self.target_dim, self.target_var, self.time_dim, self.iter_dim, self.window_dim, - self.window_history_size, self.window_history_offset, self.window_lead_time, - self.interpolation_limit, self.interpolation_method, self.min_length, self.start, self.end] - - hash = "".join([str(e) for e in hash_list]).encode("utf-8") - m = hashlib.sha256() - m.update(hash) - return m.hexdigest() + hash = "".join([str(self.__getattribute__(e)) for e in self._hash_list()]).encode() + return hashlib.md5(hash).hexdigest() if __name__ == "__main__": diff --git a/requirements.txt b/requirements.txt index b0a6e7f59896fd0edf08977ee553c803f6c2e960..af742fdea75902515cfc180cdbe43f80cef25614 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,6 +9,7 @@ chardet==4.0.0 coverage==5.4 cycler==0.10.0 dask==2021.2.0 +dill==0.3.3 fsspec==0.8.5 gast==0.4.0 grpcio==1.35.0 diff --git a/requirements_gpu.txt b/requirements_gpu.txt index 35fe0d5ee2a03f01737bc185d2a5bbaf26383806..7dd443a45df25a9e990888ab2ff061388ce36436 100644 --- a/requirements_gpu.txt +++ b/requirements_gpu.txt @@ -9,6 +9,7 @@ chardet==4.0.0 coverage==5.4 cycler==0.10.0 dask==2021.2.0 +dill==0.3.3 fsspec==0.8.5 gast==0.4.0 grpcio==1.35.0 diff --git a/test/test_data_handler/test_data_handler_mixed_sampling.py b/test/test_data_handler/test_data_handler_mixed_sampling.py index d2f9ce00224a61815c89e44b7c37a667d239b2f5..2a6553b7f495bb4eb8aeddf7c39f2f2517edc967 100644 --- a/test/test_data_handler/test_data_handler_mixed_sampling.py +++ b/test/test_data_handler/test_data_handler_mixed_sampling.py @@ -37,7 +37,7 @@ class TestDataHandlerMixedSamplingSingleStation: req = object.__new__(DataHandlerSingleStation) assert sorted(obj._requirements) == sorted(remove_items(req.requirements(), "station")) - @mock.patch("mlair.data_handler.data_handler_mixed_sampling.DataHandlerMixedSamplingSingleStation.setup_samples") + @mock.patch("mlair.data_handler.data_handler_single_station.DataHandlerSingleStation.setup_samples") def test_init(self, mock_super_init): obj = DataHandlerMixedSamplingSingleStation("first_arg", "second", {}, test=23, sampling="hourly", interpolation_limit=(1, 10))