diff --git a/src/data_handling/data_generator.py b/src/data_handling/data_generator.py index f259c4030e1293bf565c87f65fe5eff3ee1a31b5..26b12d5955f2dd44661fe1da4450cb113c37b1b7 100644 --- a/src/data_handling/data_generator.py +++ b/src/data_handling/data_generator.py @@ -5,7 +5,7 @@ import keras from src import helpers from src.data_handling.data_preparation import DataPrep import os -from typing import Union, List, Tuple +from typing import Union, List, Tuple, Any import xarray as xr import pickle import logging @@ -93,16 +93,18 @@ class DataGenerator(keras.utils.Sequence): return data.history.transpose("datetime", "window", "Stations", "variables"), \ data.label.squeeze("Stations").transpose("datetime", "window") - def get_data_generator(self, key: Union[str, int] = None, load_tmp: bool = True) -> DataPrep: + def get_data_generator(self, key: Union[str, int] = None, local_tmp_storage: bool = True) -> DataPrep: """ Select data for given key, create a DataPrep object and interpolate, transform, make history and labels and remove nans. :param key: station key to choose the data generator. + :param local_tmp_storage: say if data should be processed from scratch or loaded as already processed data from + tmp pickle file to save computational time (but of course more disk space required). :return: preprocessed data as a DataPrep instance """ station = self.get_station_key(key) try: - if not load_tmp: + if not local_tmp_storage: raise FileNotFoundError data = self._load_pickle_data(station, self.variables) except FileNotFoundError: @@ -117,15 +119,26 @@ class DataGenerator(keras.utils.Sequence): self._save_pickle_data(data) return data - def _save_pickle_data(self, data): + def _save_pickle_data(self, data: Any): + """ + Save given data locally as .pickle in self.data_path_tmp with name '<station>_<var1>_<var2>_..._<varX>.pickle' + :param data: any data, that should be saved + """ file = os.path.join(self.data_path_tmp, f"{''.join(data.station)}_{'_'.join(sorted(data.variables))}.pickle") with open(file, "wb") as f: pickle.dump(data, f) logging.debug(f"save pickle data to {file}") - def _load_pickle_data(self, station, variables): + def _load_pickle_data(self, station: Union[str, List[str]], variables: List[str]) -> Any: + """ + Load locally saved data from self.data_path_tmp and name '<station>_<var1>_<var2>_..._<varX>.pickle'. + :param station: station to load + :param variables: list of variables to load + :return: loaded data + """ file = os.path.join(self.data_path_tmp, f"{''.join(station)}_{'_'.join(sorted(variables))}.pickle") - data = pickle.load(open(file, "rb")) + with open(file, "rb") as f: + data = pickle.load(f) logging.debug(f"load pickle data from {file}") return data diff --git a/src/run_modules/pre_processing.py b/src/run_modules/pre_processing.py index 2f8b2777b8e7edd9dd2be3e425d3d2f18d81a4b2..5dc61738c37240326579e60d487ad4423302682e 100644 --- a/src/run_modules/pre_processing.py +++ b/src/run_modules/pre_processing.py @@ -119,7 +119,7 @@ class PreProcessing(RunEnvironment): t_inner.run() try: # (history, label) = data_gen[station] - data = data_gen.get_data_generator(key=station, load_tmp=load_tmp) + data = data_gen.get_data_generator(key=station, local_tmp_storage=load_tmp) valid_stations.append(station) logging.debug(f'{station}: history_shape = {data.history.transpose("datetime", "window", "Stations", "variables").shape}') logging.debug(f"{station}: loading time = {t_inner}")