From 5adcc794330cc28d4d0a28d262f061edab9129a1 Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Thu, 9 Jul 2020 11:38:03 +0200
Subject: [PATCH] DataPreparation accepts now classes instead of combination of
 id and class type

---
 src/data_handling/advanced_data_handling.py | 63 +++++++++++----------
 1 file changed, 33 insertions(+), 30 deletions(-)

diff --git a/src/data_handling/advanced_data_handling.py b/src/data_handling/advanced_data_handling.py
index 85877778..b1d20ec5 100644
--- a/src/data_handling/advanced_data_handling.py
+++ b/src/data_handling/advanced_data_handling.py
@@ -40,22 +40,25 @@ class DummyDataSingleStation:  # pragma: no cover
                                                                                   "window": range(5),
                                                                                   "variables": range(1)})
 
+    def __str__(self):
+        return self.name
+
 
 class DataPreparation:
 
-    def __init__(self, id, data_class, interpolate_dim: str, store_path, neighbor_ids=None, min_length=0,
+    def __init__(self, id_class, interpolate_dim: str, store_path, neighbors=None, min_length=0,
                  extreme_values: num_or_list = 1.,extremes_on_right_tail_only: bool = False,):
-        self.id = id
-        self.neighbor_ids = sorted(to_list(neighbor_ids)) if neighbor_ids is not None else []
+        self.id_class = id_class
+        self.neighbors = to_list(neighbors) if neighbors is not None else []
         self.interpolate_dim = interpolate_dim
         self.min_length = min_length
         self._X = None
         self._Y = None
         self._X_extreme = None
         self._Y_extreme = None
-        self._path = os.path.join(store_path, f"data_preparation_{self.id}.pickle")
+        self._save_file = os.path.join(store_path, f"data_preparation_{str(self.id_class)}.pickle")
         self._collection = []
-        self._create_collection(data_class)
+        self._create_collection()
         self.harmonise_X()
         self.multiply_extremes(extreme_values, extremes_on_right_tail_only, dim=self.interpolate_dim)
         self._store(fresh_store=True)
@@ -64,41 +67,40 @@ class DataPreparation:
         self._X, self._Y, self._X_extreme, self._Y_extreme = None, None, None, None
 
     def _cleanup(self):
-        directory = os.path.dirname(self._path)
+        directory = os.path.dirname(self._save_file)
         if os.path.exists(directory) is False:
             os.makedirs(directory)
-        if os.path.exists(self._path):
-            shutil.rmtree(self._path, ignore_errors=True)
+        if os.path.exists(self._save_file):
+            shutil.rmtree(self._save_file, ignore_errors=True)
 
     def _store(self, fresh_store=False):
         self._cleanup() if fresh_store is True else None
         data = {"X": self._X, "Y": self._Y, "X_extreme": self._X_extreme, "Y_extreme": self._Y_extreme}
-        with open(self._path, "wb") as f:
+        with open(self._save_file, "wb") as f:
             pickle.dump(data, f)
-        logging.debug(f"save pickle data to {self._path}")
+        logging.debug(f"save pickle data to {self._save_file}")
         self._reset_data()
 
     def _load(self):
         try:
-            with open(self._path, "rb") as f:
+            with open(self._save_file, "rb") as f:
                 data = pickle.load(f)
-            logging.debug(f"load pickle data from {self._path}")
+            logging.debug(f"load pickle data from {self._save_file}")
             self._X, self._Y = data["X"], data["Y"]
             self._X_extreme, self._Y_extreme = data["X_extreme"], data["Y_extreme"]
         except FileNotFoundError:
             pass
 
-    def get_data(self, upsampling=False):
+    def get_data(self, upsampling=False, as_numpy=True):
         self._load()
-        X = self.get_X(upsampling)
-        Y = self.get_Y(upsampling)
+        X = self.get_X(upsampling, as_numpy)
+        Y = self.get_Y(upsampling, as_numpy)
         self._reset_data()
         return X, Y
 
-    def _create_collection(self, data_class, **kwargs):
-        for name in [id] + self.neighbor_ids:
-            data = data_class(name, **kwargs)
-            self._collection.append(data)
+    def _create_collection(self):
+        for data_class in [self.id_class] + self.neighbors:
+            self._collection.append(data_class)
 
     def get_X_original(self):
         X = []
@@ -114,19 +116,19 @@ class DataPreparation:
     def _to_numpy(d):
         return list(map(lambda x: np.copy(x), d))
 
-    def get_X(self, upsamling=False):
+    def get_X(self, upsamling=False, as_numpy=True):
         no_data = (self._X is None)
         self._load() if no_data is True else None
         X = self._X if upsamling is False else self._X_extreme
         self._reset_data() if no_data is True else None
-        return self._to_numpy(X)
+        return self._to_numpy(X) if as_numpy is True else X
 
-    def get_Y(self, upsamling=False):
+    def get_Y(self, upsamling=False, as_numpy=True):
         no_data = (self._Y is None)
         self._load() if no_data is True else None
         Y = self._Y if upsamling is False else self._Y_extreme
         self._reset_data() if no_data is True else None
-        return self._to_numpy([Y])
+        return self._to_numpy([Y]) if as_numpy is True else Y
 
     def harmonise_X(self):
         X_original, Y_original = self.get_X_original(), self.get_Y_original()
@@ -160,7 +162,7 @@ class DataPreparation:
         """
         # check if X or Y is None
         if (self._X is None) or (self._Y is None):
-            logging.debug(f"{self.id} has no data for X or Y, skip multiply extremes")
+            logging.debug(f"{str(self.id_class)} has no data for X or Y, skip multiply extremes")
             return
 
         # check type if inputs
@@ -179,20 +181,20 @@ class DataPreparation:
                 X = self._X_extreme
                 Y = self._Y_extreme
 
-            # extract extremes based on occurance in labels
+            # extract extremes based on occurrence in labels
             other_dims = remove_items(list(Y.dims), dim)
             if extremes_on_right_tail_only:
-                extreme_Y_idx = (Y > extr_val).any(dim=other_dims)
+                extreme_idx = (Y > extr_val).any(dim=other_dims)
             else:
-                extreme_Y_idx = xr.concat([(Y < -extr_val).any(dim=other_dims[0]),
+                extreme_idx = xr.concat([(Y < -extr_val).any(dim=other_dims[0]),
                                            (Y > extr_val).any(dim=other_dims[0])],
                                           dim=other_dims[1]).any(dim=other_dims[1])
 
-            extremes_X = list(map(lambda x: x.sel(**{dim: extreme_Y_idx}), X))
+            extremes_X = list(map(lambda x: x.sel(**{dim: extreme_idx}), X))
             self._add_timedelta(extremes_X, dim, timedelta)
             # extremes_X = list(map(lambda x: x.coords[dim].values + np.timedelta64(*timedelta), extremes_X))
 
-            extremes_Y = Y.sel(**{dim: extreme_Y_idx})
+            extremes_Y = Y.sel(**{dim: extreme_idx})
             extremes_Y.coords[dim].values += np.timedelta64(*timedelta)
 
             self._Y_extreme = xr.concat([Y, extremes_Y], dim=dim)
@@ -212,6 +214,7 @@ if __name__ == "__main__":
     data.get_Y()
 
     path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata")
-    data_prep = DataPreparation("main_class", DummyDataSingleStation, "datetime", path, neighbor_ids=["neighbor1", "neighbor2"],
+    data_prep = DataPreparation(DummyDataSingleStation("main_class"), "datetime", path,
+                                neighbors=[DummyDataSingleStation("neighbor1"), DummyDataSingleStation("neighbor2")],
                                 extreme_values=[1., 1.2])
     data_prep.get_data(upsampling=False)
-- 
GitLab