diff --git a/src/data_handling/data_generator.py b/src/data_handling/data_generator.py index 19a94fbb9dbbc8f382a225c852f34971a98395b8..7aa24a88a0b80b1d4d2b54973bf02f232184a732 100644 --- a/src/data_handling/data_generator.py +++ b/src/data_handling/data_generator.py @@ -78,8 +78,7 @@ class DataGenerator(keras.utils.Sequence): data = self.get_data_generator() self._iterator += 1 if data.history is not None and data.label is not None: # pragma: no branch - return data.history.transpose("datetime", "window", "Stations", "variables"), \ - data.label.squeeze("Stations").transpose("datetime", "window") + return data.get_transposed_history(), data.get_transposed_label() else: self.__next__() # pragma: no cover else: @@ -92,7 +91,7 @@ class DataGenerator(keras.utils.Sequence): :return: The generator's time series of history data and its labels """ data = self.get_data_generator(key=item) - return data.get_transposed_history(), data.label.squeeze("Stations").transpose("datetime", "window") + return data.get_transposed_history(), data.get_transposed_label() def get_data_generator(self, key: Union[str, int] = None, local_tmp_storage: bool = True) -> DataPrep: """ diff --git a/src/data_handling/data_preparation.py b/src/data_handling/data_preparation.py index 5bca71f52c9f136b5910d4e080491e0ff86484ae..490515aafaf51044ffb1121d276a3bdec4912fff 100644 --- a/src/data_handling/data_preparation.py +++ b/src/data_handling/data_preparation.py @@ -388,7 +388,10 @@ class DataPrep(object): def get_transposed_history(self): if self.history is not None: - return self.history.transpose("datetime", "window", "Stations", "variables") + return self.history.transpose("datetime", "window", "Stations", "variables").copy() + + def get_transposed_label(self): + return self.label.squeeze("Stations").transpose("datetime", "window").copy() if __name__ == "__main__":