Skip to content
Snippets Groups Projects
Commit df48442f authored by lukas leufen's avatar lukas leufen
Browse files

refac transpose of history and labe

parent 64e9ad15
Branches
Tags
2 merge requests!59Develop,!52implemented bootstraps
Pipeline #29758 passed
......@@ -78,8 +78,7 @@ class DataGenerator(keras.utils.Sequence):
data = self.get_data_generator()
self._iterator += 1
if data.history is not None and data.label is not None: # pragma: no branch
return data.history.transpose("datetime", "window", "Stations", "variables"), \
data.label.squeeze("Stations").transpose("datetime", "window")
return data.get_transposed_history(), data.get_transposed_label()
else:
self.__next__() # pragma: no cover
else:
......@@ -92,7 +91,7 @@ class DataGenerator(keras.utils.Sequence):
:return: The generator's time series of history data and its labels
"""
data = self.get_data_generator(key=item)
return data.get_transposed_history(), data.label.squeeze("Stations").transpose("datetime", "window")
return data.get_transposed_history(), data.get_transposed_label()
def get_data_generator(self, key: Union[str, int] = None, local_tmp_storage: bool = True) -> DataPrep:
"""
......
......@@ -388,7 +388,10 @@ class DataPrep(object):
def get_transposed_history(self):
if self.history is not None:
return self.history.transpose("datetime", "window", "Stations", "variables")
return self.history.transpose("datetime", "window", "Stations", "variables").copy()
def get_transposed_label(self):
return self.label.squeeze("Stations").transpose("datetime", "window").copy()
if __name__ == "__main__":
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment