Commit a6dccb6d authored by lukas leufen

first implementation of local tmp storage using pickle

parent d65b8b78
2 merge requests: !37 include new development, !33 Lukas issue036 feat local temp data storage
Pipeline #29077 passed
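The commit adds a pickle-based cache under <data_path>/tmp: a prepared DataPrep object is written to disk once and reloaded on later requests instead of being rebuilt from the raw source. A minimal, self-contained sketch of that load-or-build-and-cache pattern (function and argument names are illustrative, not taken from the repository; only the os, pickle and logging calls are real APIs):

import logging
import os
import pickle


def load_or_build(data_path_tmp, station, variables, build_fn, use_cache=True):
    # cache key: station id plus the sorted variable list, mirroring the diff below
    file = os.path.join(data_path_tmp, f"{station}_{'_'.join(sorted(variables))}.pickle")
    if use_cache and os.path.exists(file):
        logging.debug(f"load pickle data from {file}")
        with open(file, "rb") as f:
            return pickle.load(f)
    data = build_fn(station, variables)  # expensive preparation (download, interpolate, ...)
    os.makedirs(data_path_tmp, exist_ok=True)
    with open(file, "wb") as f:
        pickle.dump(data, f)
    logging.debug(f"save pickle data to {file}")
    return data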
@@ -7,6 +7,8 @@ from src.data_handling.data_preparation import DataPrep
import os
from typing import Union, List, Tuple
import xarray as xr
import pickle
import logging
class DataGenerator(keras.utils.Sequence):
@@ -23,6 +25,9 @@ class DataGenerator(keras.utils.Sequence):
interpolate_method: str = "linear", limit_nan_fill: int = 1, window_history_size: int = 7,
window_lead_time: int = 4, transform_method: str = "standardise", **kwargs):
self.data_path = os.path.abspath(data_path)
self.data_path_tmp = os.path.join(os.path.abspath(data_path), "tmp")
if not os.path.exists(self.data_path_tmp):
os.makedirs(self.data_path_tmp)
self.network = network
self.stations = helpers.to_list(stations)
self.variables = variables
@@ -88,7 +93,7 @@
return data.history.transpose("datetime", "window", "Stations", "variables"), \
data.label.squeeze("Stations").transpose("datetime", "window")
def get_data_generator(self, key: Union[str, int] = None) -> DataPrep:
def get_data_generator(self, key: Union[str, int] = None, load_tmp: bool = True) -> DataPrep:
"""
Select data for given key, create a DataPrep object and interpolate, transform, make history and labels and
remove nans.
@@ -96,6 +101,12 @@
:return: preprocessed data as a DataPrep instance
"""
station = self.get_station_key(key)
try:
if not load_tmp:
raise FileNotFoundError
data = self._load_pickle_data(station, self.variables)
except FileNotFoundError:
logging.info(f"load not pickle data for {station}")
data = DataPrep(self.data_path, self.network, station, self.variables, station_type=self.station_type,
**self.kwargs)
data.interpolate(self.interpolate_dim, method=self.interpolate_method, limit=self.limit_nan_fill)
@@ -103,6 +114,19 @@
data.make_history_window(self.interpolate_dim, self.window_history_size)
data.make_labels(self.target_dim, self.target_var, self.interpolate_dim, self.window_lead_time)
data.history_label_nan_remove(self.interpolate_dim)
self._save_pickle_data(data)
return data
def _save_pickle_data(self, data):
file = os.path.join(self.data_path_tmp, f"{''.join(data.station)}_{'_'.join(sorted(data.variables))}.pickle")
with open(file, "wb") as f:
pickle.dump(data, f)
logging.debug(f"save pickle data to {file}")
def _load_pickle_data(self, station, variables):
file = os.path.join(self.data_path_tmp, f"{''.join(station)}_{'_'.join(sorted(variables))}.pickle")
with open(file, "rb") as f:
    data = pickle.load(f)
logging.debug(f"load pickle data from {file}")
return data
def get_station_key(self, key: Union[None, str, int, List[Union[None, str, int]]]) -> str:
......
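Taken together, get_data_generator now tries _load_pickle_data first and only falls back to building a fresh DataPrep (followed by _save_pickle_data) when no cache file exists or load_tmp is False. The cache file name combines the station id with the sorted variable names, so the same station requested with a different variable set gets its own pickle. A hedged usage sketch (constructor arguments abbreviated and paths invented for illustration):

gen = DataGenerator(data_path="/data/toar", network="AIRBASE", stations=["DEBW107"],
                    variables=["o3", "temp"], interpolate_dim="datetime",
                    target_dim="variables", target_var="o3")
d1 = gen.get_data_generator("DEBW107")                  # builds DataPrep, writes /data/toar/tmp/DEBW107_o3_temp.pickle
d2 = gen.get_data_generator("DEBW107")                  # second call is served from the pickle file
d3 = gen.get_data_generator("DEBW107", load_tmp=False)  # forces a rebuild and refreshes the cache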
@@ -36,7 +36,7 @@ class PreProcessing(RunEnvironment):
def _run(self):
args = self.data_store.create_args_dict(DEFAULT_ARGS_LIST, scope="general.preprocessing")
kwargs = self.data_store.create_args_dict(DEFAULT_KWARGS_LIST, scope="general.preprocessing")
valid_stations = self.check_valid_stations(args, kwargs, self.data_store.get("stations", "general"))
valid_stations = self.check_valid_stations(args, kwargs, self.data_store.get("stations", "general"), load_tmp=False)
self.data_store.set("stations", valid_stations, "general")
self.split_train_val_test()
self.report_pre_processing()
@@ -97,7 +97,7 @@
self.data_store.set("generator", data_set, scope)
@staticmethod
def check_valid_stations(args: Dict, kwargs: Dict, all_stations: List[str]):
def check_valid_stations(args: Dict, kwargs: Dict, all_stations: List[str], load_tmp=True):
"""
Check if all given stations in `all_stations` are valid. Valid means, that there is data available for the given
time range (is included in `kwargs`). The shape and the loading time are logged in debug mode.
@@ -118,9 +118,10 @@
for station in all_stations:
t_inner.run()
try:
(history, label) = data_gen[station]
# (history, label) = data_gen[station]
data = data_gen.get_data_generator(key=station, load_tmp=load_tmp)
valid_stations.append(station)
logging.debug(f"{station}: history_shape = {history.shape}")
logging.debug(f'{station}: history_shape = {data.history.transpose("datetime", "window", "Stations", "variables").shape}')
logging.debug(f"{station}: loading time = {t_inner}")
except (AttributeError, EmptyQueryResult):
continue
......
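In PreProcessing, the station validity check now calls get_data_generator directly with load_tmp=False, so each station is built from the raw source exactly once and its pickle cache is written (or refreshed) as a side effect; the history shape for the debug log is taken from the returned DataPrep instead of unpacking (history, label). Presumably the later train/val/test generators then go through the default load_tmp=True path and read from that cache. A hedged sketch of such a follow-up access (station id is only an example):

# later accesses keep the default load_tmp=True and read the pickle written during the check
history, label = data_gen["DEBW107"]  # __getitem__ -> get_data_generator -> _load_pickle_data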