From df48442ff19bd836d3a3fe2c2d206285254a1a5f Mon Sep 17 00:00:00 2001 From: lukas leufen <l.leufen@fz-juelich.de> Date: Mon, 17 Feb 2020 14:30:34 +0100 Subject: [PATCH] refac transpose of history and labe --- src/data_handling/data_generator.py | 5 ++--- src/data_handling/data_preparation.py | 5 ++++- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/data_handling/data_generator.py b/src/data_handling/data_generator.py index 19a94fbb..7aa24a88 100644 --- a/src/data_handling/data_generator.py +++ b/src/data_handling/data_generator.py @@ -78,8 +78,7 @@ class DataGenerator(keras.utils.Sequence): data = self.get_data_generator() self._iterator += 1 if data.history is not None and data.label is not None: # pragma: no branch - return data.history.transpose("datetime", "window", "Stations", "variables"), \ - data.label.squeeze("Stations").transpose("datetime", "window") + return data.get_transposed_history(), data.get_transposed_label() else: self.__next__() # pragma: no cover else: @@ -92,7 +91,7 @@ class DataGenerator(keras.utils.Sequence): :return: The generator's time series of history data and its labels """ data = self.get_data_generator(key=item) - return data.get_transposed_history(), data.label.squeeze("Stations").transpose("datetime", "window") + return data.get_transposed_history(), data.get_transposed_label() def get_data_generator(self, key: Union[str, int] = None, local_tmp_storage: bool = True) -> DataPrep: """ diff --git a/src/data_handling/data_preparation.py b/src/data_handling/data_preparation.py index 5bca71f5..490515aa 100644 --- a/src/data_handling/data_preparation.py +++ b/src/data_handling/data_preparation.py @@ -388,7 +388,10 @@ class DataPrep(object): def get_transposed_history(self): if self.history is not None: - return self.history.transpose("datetime", "window", "Stations", "variables") + return self.history.transpose("datetime", "window", "Stations", "variables").copy() + + def get_transposed_label(self): + return self.label.squeeze("Stations").transpose("datetime", "window").copy() if __name__ == "__main__": -- GitLab