Skip to content
Snippets Groups Projects
Commit aa18ad42 authored by lukas leufen's avatar lukas leufen
Browse files

added make_history_window and make_labels

parent bddb1866
Branches
Tags
2 merge requests!6updated inception model and data prep class,!4data prep class
...@@ -184,10 +184,21 @@ class DataPrep(object): ...@@ -184,10 +184,21 @@ class DataPrep(object):
else: else:
self.inverse_transform() self.inverse_transform()
def make_history_window(self, dim, window): def make_history_window(self, dim: str, window: int) -> None:
raise NotImplementedError """
This function uses shifts the data window+1 times and returns a xarray which has a new dimension 'window'
containing the shifted data. This is used to represent history in the data. Results are stored in self.history .
:param dim: Dimension along shift will be applied
:param window: number of time steps to look back in history
Note: window will be treated as negative value. This should be in agreement with looking back on
a time line. Nonetheless positive values are allowed but they are converted to its negative
expression
"""
window = -abs(window)
self.history = self.shift(dim, window)
def shift(self, dim: str, window: int): def shift(self, dim: str, window: int) -> xr.DataArray:
""" """
This function uses xarray's shift function multiple times to represent history (if window <= 0) This function uses xarray's shift function multiple times to represent history (if window <= 0)
or lead time (if window > 0) or lead time (if window > 0)
...@@ -208,8 +219,17 @@ class DataPrep(object): ...@@ -208,8 +219,17 @@ class DataPrep(object):
res = xr.concat(res, dim=window_array) res = xr.concat(res, dim=window_array)
return res return res
def make_labels(self, dimension_name_of_target, target_variable, dimension_name_of_shift, window): def make_labels(self, dim_name_of_target: str, target_var: str, dim_name_of_shift: str, window: int) -> None:
raise NotImplementedError """
This function creates a xarray.DataArray containing labels
:param dim_name_of_target: Name of dimension which contains the target variable
:param target_var: Name of target variable in 'dimension'
:param dim_name_of_shift: Name of dimension on which xarray.DataArray.shift will be applied
:param window: lead time of label
"""
window = abs(window)
self.label = self.shift(dim_name_of_shift, window).sel({dim_name_of_target: target_var})
def history_label_nan_remove(self, dim: str) -> None: def history_label_nan_remove(self, dim: str) -> None:
""" """
......
...@@ -189,3 +189,20 @@ class TestDataPrep: ...@@ -189,3 +189,20 @@ class TestDataPrep:
window, orig = self.extract_window_data(res, data.data, 0) window, orig = self.extract_window_data(res, data.data, 0)
assert list(res.data.shape) == [1] + list(data.data.shape) assert list(res.data.shape) == [1] + list(data.data.shape)
assert np.testing.assert_array_equal(orig, window) is None assert np.testing.assert_array_equal(orig, window) is None
def test_make_history_window(self, data):
assert data.history is None
data.make_history_window('datetime', 5)
assert data.history is not None
save_history = data.history
data.make_history_window('datetime', -5)
assert np.testing.assert_array_equal(data.history, save_history) is None
def test_make_labels(self, data):
assert data.label is None
data.make_labels('variables', 'o3', 'datetime', 3)
assert data.label.variables.data == 'o3'
assert list(data.label.shape) == [3] + list(data.data.shape)[:2]
save_label = data.label
data.make_labels('variables', 'o3', 'datetime', -3)
assert np.testing.assert_array_equal(data.label, save_label) is None
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment