From df48442ff19bd836d3a3fe2c2d206285254a1a5f Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Mon, 17 Feb 2020 14:30:34 +0100
Subject: [PATCH] refac transpose of history and labe

---
 src/data_handling/data_generator.py   | 5 ++---
 src/data_handling/data_preparation.py | 5 ++++-
 2 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/src/data_handling/data_generator.py b/src/data_handling/data_generator.py
index 19a94fbb..7aa24a88 100644
--- a/src/data_handling/data_generator.py
+++ b/src/data_handling/data_generator.py
@@ -78,8 +78,7 @@ class DataGenerator(keras.utils.Sequence):
             data = self.get_data_generator()
             self._iterator += 1
             if data.history is not None and data.label is not None:  # pragma: no branch
-                return data.history.transpose("datetime", "window", "Stations", "variables"), \
-                    data.label.squeeze("Stations").transpose("datetime", "window")
+                return data.get_transposed_history(), data.get_transposed_label()
             else:
                 self.__next__()  # pragma: no cover
         else:
@@ -92,7 +91,7 @@ class DataGenerator(keras.utils.Sequence):
         :return: The generator's time series of history data and its labels
         """
         data = self.get_data_generator(key=item)
-        return data.get_transposed_history(), data.label.squeeze("Stations").transpose("datetime", "window")
+        return data.get_transposed_history(), data.get_transposed_label()
 
     def get_data_generator(self, key: Union[str, int] = None, local_tmp_storage: bool = True) -> DataPrep:
         """
diff --git a/src/data_handling/data_preparation.py b/src/data_handling/data_preparation.py
index 5bca71f5..490515aa 100644
--- a/src/data_handling/data_preparation.py
+++ b/src/data_handling/data_preparation.py
@@ -388,7 +388,10 @@ class DataPrep(object):
 
     def get_transposed_history(self):
         if self.history is not None:
-            return self.history.transpose("datetime", "window", "Stations", "variables")
+            return self.history.transpose("datetime", "window", "Stations", "variables").copy()
+
+    def get_transposed_label(self):
+        return self.label.squeeze("Stations").transpose("datetime", "window").copy()
 
 
 if __name__ == "__main__":
-- 
GitLab