From b29ec5c361f225dca8804f0bad94aa296cbbe0ed Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Thu, 9 Jul 2020 14:00:42 +0200
Subject: [PATCH] rename get_x/get_y to get_X/get_Y in StationPrep,
 extreme_values now defaults to None

---
 src/data_handling/advanced_data_handling.py | 49 +++++++++++++++++++--
 src/data_handling/data_preparation.py       |  7 ++-
 2 files changed, 51 insertions(+), 5 deletions(-)

diff --git a/src/data_handling/advanced_data_handling.py b/src/data_handling/advanced_data_handling.py
index b1d20ec5..3dcb78e0 100644
--- a/src/data_handling/advanced_data_handling.py
+++ b/src/data_handling/advanced_data_handling.py
@@ -47,7 +47,7 @@ class DummyDataSingleStation:  # pragma: no cover
 class DataPreparation:
 
     def __init__(self, id_class, interpolate_dim: str, store_path, neighbors=None, min_length=0,
-                 extreme_values: num_or_list = 1.,extremes_on_right_tail_only: bool = False,):
+                 extreme_values: num_or_list = None,extremes_on_right_tail_only: bool = False,):
         self.id_class = id_class
         self.neighbors = to_list(neighbors) if neighbors is not None else []
         self.interpolate_dim = interpolate_dim
@@ -102,6 +102,9 @@ class DataPreparation:
         for data_class in [self.id_class] + self.neighbors:
             self._collection.append(data_class)
 
+    def __repr__(self):  # ";"-separated str() of every entry in self._collection
+        return ";".join(map(str, self._collection))
+
     def get_X_original(self):
         X = []
         for data in self._collection:
@@ -164,6 +167,10 @@ class DataPreparation:
         if (self._X is None) or (self._Y is None):
             logging.debug(f"{str(self.id_class)} has no data for X or Y, skip multiply extremes")
             return
+        if extreme_values is None:
+            logging.debug(f"No extreme values given, skip multiply extremes")
+            self._X_extreme, self._Y_extreme = self._X, self._Y
+            return
 
         # check type if inputs
         extreme_values = to_list(extreme_values)
@@ -206,8 +213,7 @@ class DataPreparation:
             d.coords[dim].values += np.timedelta64(*timedelta)
 
 
-
-if __name__ == "__main__":
+def run_data_prep():
 
     data = DummyDataSingleStation("main_class")
     data.get_X()
@@ -218,3 +224,40 @@ if __name__ == "__main__":
                                 neighbors=[DummyDataSingleStation("neighbor1"), DummyDataSingleStation("neighbor2")],
                                 extreme_values=[1., 1.2])
     data_prep.get_data(upsampling=False)
+
+
+def create_data_prep():
+    """
+    Build three DataPreparation objects: each station is the main class once, the other two are its neighbors.
+    """
+    # Local import: StationPrep is otherwise bound only inside the __main__ guard (NameError for importers).
+    from src.data_handling.data_preparation import StationPrep
+
+    path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata")
+    interpolate_dim = 'datetime'
+    # Shared trailing positional arguments, in order: station_type, network, sampling, target_dim, target_var,
+    # interpolate_dim, window_history_size, window_lead_time.
+    common = (None, 'UBA', 'daily', 'variables', 'o3', interpolate_dim, 7, 3)
+    central_station = StationPrep(path, "DEBW011", {'o3': 'dma8eu', 'temp': 'maximum'}, *common)
+    # NOTE(review): only this station requests 'temp-rea-miub' instead of 'temp' - confirm this is intended.
+    neighbor1 = StationPrep(path, "DEBW013", {'o3': 'dma8eu', 'temp-rea-miub': 'maximum'}, *common)
+    neighbor2 = StationPrep(path, "DEBW034", {'o3': 'dma8eu', 'temp': 'maximum'}, *common)
+
+    # One DataPreparation per station, in the order central_station, neighbor1, neighbor2.
+    return [
+        DataPreparation(central_station, interpolate_dim, path, neighbors=[neighbor1, neighbor2]),
+        DataPreparation(neighbor1, interpolate_dim, path, neighbors=[central_station, neighbor2]),
+        DataPreparation(neighbor2, interpolate_dim, path, neighbors=[neighbor1, central_station]),
+    ]
+
+if __name__ == "__main__":
+    # Manual run: wrap the example stations in a DataCollection, print each entry, index the KerasIterator once.
+    from src.data_handling.data_preparation import StationPrep
+    from src.data_handling.iterator import KerasIterator, DataCollection
+    data_collection = DataCollection(create_data_prep())
+    for entry in data_collection:
+        print(entry)
+    keras_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata", "keras")
+    keras_it = KerasIterator(data_collection, 100, keras_path)
+    keras_it[2]
+
diff --git a/src/data_handling/data_preparation.py b/src/data_handling/data_preparation.py
index 70722b1f..b49e4b90 100644
--- a/src/data_handling/data_preparation.py
+++ b/src/data_handling/data_preparation.py
@@ -42,6 +42,9 @@ class AbstractStationPrep():
         self.label = None
         self.observation = None
 
+    def __str__(self):  # string form is the first entry of self.station
+        return self.station[0]  # NOTE(review): __str__ must return a str - presumably a station id; confirm
+
     def load_data(self):
         try:
             self.read_data_from_disk()
@@ -311,10 +314,10 @@ class StationPrep(AbstractStationPrep):
         """
         return self.label.squeeze("Stations").transpose("datetime", "window").copy()
 
-    def get_x(self):
+    def get_X(self):  # returns the result of get_transposed_history() unchanged
         return self.get_transposed_history()
 
-    def get_y(self):
+    def get_Y(self):  # returns the result of get_transposed_label() unchanged
         return self.get_transposed_label()
 
     def make_samples(self):
-- 
GitLab