Skip to content
Snippets Groups Projects
Commit bddb1866 authored by lukas leufen's avatar lukas leufen
Browse files

added slice, create index array

parent f76bf2ff
No related branches found
No related tags found
2 merge requests!6updated inception model and data prep class,!4data prep class
......@@ -4,11 +4,12 @@ __date__ = '2019-10-16'
import xarray as xr
import pandas as pd
import numpy as np
import logging
import os
from src import join, helpers
from src import statistics
from typing import Union, List, Dict
from typing import Union, List, Dict, Iterable
class DataPrep(object):
......@@ -20,7 +21,6 @@ class DataPrep(object):
self.variables = variables
self.mean = None
self.std = None
self.df = None
self.history = None
self.label = None
self.kwargs = kwargs
......@@ -187,18 +187,65 @@ class DataPrep(object):
def make_history_window(self, dim, window):
raise NotImplementedError
def shift(self, dim, window):
raise NotImplementedError
def shift(self, dim: str, window: int):
"""
This function uses xarray's shift function multiple times to represent history (if window <= 0)
or lead time (if window > 0)
:param dim: dimension along shift is applied
:param window: number of steps to shift (corresponds to the window length)
:return:
"""
start = 1
end = 1
if window <= 0:
start = window
else:
end = window + 1
res = []
for w in range(start, end):
res.append(self.data.shift({dim: -w}))
window_array = self.create_index_array('window', range(start, end))
res = xr.concat(res, dim=window_array)
return res
def make_labels(self, dimension_name_of_target, target_variable, dimension_name_of_shift, window):
raise NotImplementedError
def history_label_nan_remove(self, dim):
raise NotImplementedError
def history_label_nan_remove(self, dim: str) -> None:
"""
All NAs slices in dim which contain nans in self.history or self.label are removed in both data sets.
This is done to present only a full matrix to keras.fit.
:param dim:
:return:
"""
intersect = []
if (self.history is not None) and (self.label is not None):
non_nan_history = self.history.dropna(dim=dim)
non_nan_label = self.label.dropna(dim=dim)
intersect = np.intersect1d(non_nan_history.coords[dim].values,
non_nan_label.coords[dim].values)
if len(intersect) == 0:
self.history = None
self.label = None
else:
self.history = self.history.sel({dim: intersect})
self.label = self.label.sel({dim: intersect})
@staticmethod
def create_indexarray(index_name, index_values):
raise NotImplementedError
def create_index_array(index_name: str, index_value: Iterable[int]) -> xr.DataArray:
"""
This Function crates a 1D xarray.DataArray with given index name and value
:param index_name:
:param index_value:
:return:
"""
ind = pd.DataFrame({'val': index_value}, index=index_value)
res = xr.Dataset.from_dataframe(ind).to_array().rename({'index': index_name}).squeeze(dim='variable', drop=True)
res.name = index_name
return res
def _slice_prep(self, data, coord='datetime'):
raise NotImplementedError
......
......@@ -4,6 +4,8 @@ from src.data_preparation import DataPrep
import logging
import numpy as np
import xarray as xr
import datetime as dt
import pandas as pd
class TestDataPrep:
......@@ -138,3 +140,52 @@ class TestDataPrep:
data._transform_method = method
with pytest.raises(NotImplementedError):
data.inverse_transform()
def test_nan_remove_no_history(self, data):
assert data.history is None
assert data.label is None
data.history_label_nan_remove('datetime')
assert data.history is None
assert data.label is None
def test_nan_remove(self, data):
pass
def test_create_index_array(self, data):
index_array = data.create_index_array('window', range(1, 4))
assert np.testing.assert_array_equal(index_array.data, [1, 2, 3]) is None
assert index_array.name == 'window'
assert index_array.coords.dims == ('window', )
index_array = data.create_index_array('window', range(0, 1))
assert np.testing.assert_array_equal(index_array.data, [0]) is None
assert index_array.name == 'window'
assert index_array.coords.dims == ('window', )
@staticmethod
def extract_window_data(res, orig, w):
slice = {'variables': ['temp'], 'Stations': 'DEBW107', 'datetime': dt.datetime(1997, 1, 6)}
window = res.sel(slice).data.flatten()
if w <= 0:
delta = w
w = abs(w)+1
else:
delta = 1
slice = {'variables': ['temp'], 'Stations': 'DEBW107',
'datetime': pd.date_range(dt.date(1997, 1, 6) + dt.timedelta(days=delta), periods=w, freq='D')}
orig_slice = orig.sel(slice).data.flatten()
return window, orig_slice
def test_shift(self, data):
res = data.shift('datetime', 4)
window, orig = self.extract_window_data(res, data.data, 4)
assert res.coords.dims == ('window', 'Stations', 'datetime', 'variables')
assert list(res.data.shape) == [4] + list(data.data.shape)
assert np.testing.assert_array_equal(orig, window) is None
res = data.shift('datetime', -3)
window, orig = self.extract_window_data(res, data.data, -3)
assert list(res.data.shape) == [4] + list(data.data.shape)
assert np.testing.assert_array_equal(orig, window) is None
res = data.shift('datetime', 0)
window, orig = self.extract_window_data(res, data.data, 0)
assert list(res.data.shape) == [1] + list(data.data.shape)
assert np.testing.assert_array_equal(orig, window) is None
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment