diff --git a/src/data_preparation.py b/src/data_preparation.py
index f087a59c7dec18b04cc4accf322e9a80bb724c2f..e45b3e8b30bcdbc6d1afe6fc8b488ed7b5edd8b3 100644
--- a/src/data_preparation.py
+++ b/src/data_preparation.py
@@ -4,11 +4,12 @@ __date__ = '2019-10-16'
 
 import xarray as xr
 import pandas as pd
+import numpy as np
 import logging
 import os
 from src import join, helpers
 from src import statistics
-from typing import Union, List, Dict
+from typing import Union, List, Dict, Iterable
 
 
 class DataPrep(object):
@@ -20,7 +21,6 @@ class DataPrep(object):
         self.variables = variables
         self.mean = None
         self.std = None
-        self.df = None
         self.history = None
         self.label = None
         self.kwargs = kwargs
@@ -187,18 +187,65 @@ class DataPrep(object):
     def make_history_window(self, dim, window):
         raise NotImplementedError
 
-    def shift(self, dim, window):
-        raise NotImplementedError
+    def shift(self, dim: str, window: int):
+        """
+        This function uses xarray's shift function multiple times to represent history (if window <= 0)
+        or lead time (if window > 0)
+        :param dim: dimension along which the shift is applied
+        :param window: number of steps to shift (corresponds to the window length)
+        :return: the shifted copies of self.data, concatenated along a new 'window' dimension
+        """
+        start = 1
+        end = 1
+        if window <= 0:
+            start = window
+        else:
+            end = window + 1
+        res = []
+        for w in range(start, end):
+            res.append(self.data.shift({dim: -w}))
+        window_array = self.create_index_array('window', range(start, end))
+        res = xr.concat(res, dim=window_array)
+        return res
 
     def make_labels(self, dimension_name_of_target, target_variable, dimension_name_of_shift, window):
         raise NotImplementedError
 
-    def history_label_nan_remove(self, dim):
-        raise NotImplementedError
+    def history_label_nan_remove(self, dim: str) -> None:
+        """
+        All slices along dim which contain NaNs in self.history or self.label are removed from both data sets.
+        This is done to present only a full matrix to keras.fit.
+
+        :param dim: dimension along which slices containing NaNs are dropped
+        :return: None, self.history and self.label are updated (both set to None if no common index is left)
+        """
+        intersect = []
+        if (self.history is not None) and (self.label is not None):
+            non_nan_history = self.history.dropna(dim=dim)
+            non_nan_label = self.label.dropna(dim=dim)
+            intersect = np.intersect1d(non_nan_history.coords[dim].values,
+                                       non_nan_label.coords[dim].values)
+
+        if len(intersect) == 0:
+            self.history = None
+            self.label = None
+        else:
+            self.history = self.history.sel({dim: intersect})
+            self.label = self.label.sel({dim: intersect})
 
     @staticmethod
-    def create_indexarray(index_name, index_values):
-        raise NotImplementedError
+    def create_index_array(index_name: str, index_value: Iterable[int]) -> xr.DataArray:
+        """
+        This function creates a 1D xarray.DataArray with given index name and value
+
+        :param index_name: name given to the returned array and to its only dimension
+        :param index_value: values used both as data and as index of the returned array
+        :return: 1D xarray.DataArray named index_name
+        """
+        ind = pd.DataFrame({'val': index_value}, index=index_value)
+        res = xr.Dataset.from_dataframe(ind).to_array().rename({'index': index_name}).squeeze(dim='variable', drop=True)
+        res.name = index_name
+        return res
 
     def _slice_prep(self, data, coord='datetime'):
         raise NotImplementedError
diff --git a/test/test_data_preparation.py b/test/test_data_preparation.py
index d228664e64b5383c906435f642a7561d510267cc..4102b6306512c6fed777591617f2bbcfa6caae34 100644
--- a/test/test_data_preparation.py
+++ b/test/test_data_preparation.py
@@ -4,6 +4,8 @@ from src.data_preparation import DataPrep
 import logging
 import numpy as np
 import xarray as xr
+import datetime as dt
+import pandas as pd
 
 
 class TestDataPrep:
@@ -138,3 +140,52 @@ class TestDataPrep:
         data._transform_method = method
         with pytest.raises(NotImplementedError):
             data.inverse_transform()
+
+    def test_nan_remove_no_history(self, data):
+        assert data.history is None
+        assert data.label is None
+        data.history_label_nan_remove('datetime')
+        assert data.history is None
+        assert data.label is None
+
+    def test_nan_remove(self, data):
+        pass
+
+    def test_create_index_array(self, data):
+        index_array = data.create_index_array('window', range(1, 4))
+        assert np.testing.assert_array_equal(index_array.data, [1, 2, 3]) is None
+        assert index_array.name == 'window'
+        assert index_array.coords.dims == ('window', )
+        index_array = data.create_index_array('window', range(0, 1))
+        assert np.testing.assert_array_equal(index_array.data, [0]) is None
+        assert index_array.name == 'window'
+        assert index_array.coords.dims == ('window', )
+
+    @staticmethod
+    def extract_window_data(res, orig, w):
+        slice = {'variables': ['temp'], 'Stations': 'DEBW107', 'datetime': dt.datetime(1997, 1, 6)}
+        window = res.sel(slice).data.flatten()
+        if w <= 0:
+            delta = w
+            w = abs(w)+1
+        else:
+            delta = 1
+        slice = {'variables': ['temp'], 'Stations': 'DEBW107',
+                 'datetime': pd.date_range(dt.date(1997, 1, 6) + dt.timedelta(days=delta), periods=w, freq='D')}
+        orig_slice = orig.sel(slice).data.flatten()
+        return window, orig_slice
+
+    def test_shift(self, data):
+        res = data.shift('datetime', 4)
+        window, orig = self.extract_window_data(res, data.data, 4)
+        assert res.coords.dims == ('window', 'Stations', 'datetime', 'variables')
+        assert list(res.data.shape) == [4] + list(data.data.shape)
+        assert np.testing.assert_array_equal(orig, window) is None
+        res = data.shift('datetime', -3)
+        window, orig = self.extract_window_data(res, data.data, -3)
+        assert list(res.data.shape) == [4] + list(data.data.shape)
+        assert np.testing.assert_array_equal(orig, window) is None
+        res = data.shift('datetime', 0)
+        window, orig = self.extract_window_data(res, data.data, 0)
+        assert list(res.data.shape) == [1] + list(data.data.shape)
+        assert np.testing.assert_array_equal(orig, window) is None