Commit a6dccb6d authored by Lukas Leufen

first implementation of local tmp storage using pickle

parent d65b8b78
2 merge requests: !37 include new development, !33 Lukas issue036 feat local temp data storage
Pipeline #29077 passed
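For orientation before the diff: the commit caches each station's prepared DataPrep object as a pickle file under <data_path>/tmp and reads it back on later runs. A minimal standalone sketch of that load-or-build pattern, assuming a hypothetical build_data callback (the helper and its names are illustrative, not the repository's API):

import logging
import os
import pickle

def load_or_build(cache_dir, station, variables, build_data, use_cache=True):
    # the cache file is keyed by station name plus the sorted variable list,
    # mirroring _save_pickle_data/_load_pickle_data in the diff below
    file = os.path.join(cache_dir, f"{station}_{'_'.join(sorted(variables))}.pickle")
    if use_cache and os.path.exists(file):
        logging.debug(f"load pickle data from {file}")
        with open(file, "rb") as f:
            return pickle.load(f)
    data = build_data(station, variables)  # the expensive preparation step
    os.makedirs(cache_dir, exist_ok=True)
    with open(file, "wb") as f:
        pickle.dump(data, f)
    logging.debug(f"save pickle data to {file}")
    return data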
@@ -7,6 +7,8 @@ from src.data_handling.data_preparation import DataPrep
import os
from typing import Union, List, Tuple
import xarray as xr
import pickle
import logging
class DataGenerator(keras.utils.Sequence):
@@ -23,6 +25,9 @@ class DataGenerator(keras.utils.Sequence):
interpolate_method: str = "linear", limit_nan_fill: int = 1, window_history_size: int = 7,
window_lead_time: int = 4, transform_method: str = "standardise", **kwargs):
self.data_path = os.path.abspath(data_path)
self.data_path_tmp = os.path.join(os.path.abspath(data_path), "tmp")
if not os.path.exists(self.data_path_tmp):
os.makedirs(self.data_path_tmp)
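# the tmp subdirectory holds the pickle cache written by _save_pickle_data;
# note: this exists-check/makedirs pair can race if two runs start at once,
# os.makedirs(self.data_path_tmp, exist_ok=True) (Python >= 3.2) avoids that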
self.network = network
self.stations = helpers.to_list(stations)
self.variables = variables
@@ -88,7 +93,7 @@
return data.history.transpose("datetime", "window", "Stations", "variables"), \
data.label.squeeze("Stations").transpose("datetime", "window")
def get_data_generator(self, key: Union[str, int] = None) -> DataPrep:
def get_data_generator(self, key: Union[str, int] = None, load_tmp: bool = True) -> DataPrep:
"""
Select data for the given key, create a DataPrep object, then interpolate, transform, create the history
and label windows, and remove NaNs.
@@ -96,6 +101,12 @@
:return: preprocessed data as a DataPrep instance
"""
station = self.get_station_key(key)
try:
if not load_tmp:
raise FileNotFoundError
data = self._load_pickle_data(station, self.variables)
except FileNotFoundError:
logging.info(f"could not load pickle data for {station}")
data = DataPrep(self.data_path, self.network, station, self.variables, station_type=self.station_type,
**self.kwargs)
data.interpolate(self.interpolate_dim, method=self.interpolate_method, limit=self.limit_nan_fill)
@@ -103,6 +114,19 @@
data.make_history_window(self.interpolate_dim, self.window_history_size)
data.make_labels(self.target_dim, self.target_var, self.interpolate_dim, self.window_lead_time)
data.history_label_nan_remove(self.interpolate_dim)
self._save_pickle_data(data)
return data
def _save_pickle_data(self, data):
file = os.path.join(self.data_path_tmp, f"{''.join(data.station)}_{'_'.join(sorted(data.variables))}.pickle")
with open(file, "wb") as f:
pickle.dump(data, f)
logging.debug(f"save pickle data to {file}")
def _load_pickle_data(self, station, variables):
file = os.path.join(self.data_path_tmp, f"{''.join(station)}_{'_'.join(sorted(variables))}.pickle")
with open(file, "rb") as f:
data = pickle.load(f)
logging.debug(f"load pickle data from {file}")
return data
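The cache key flattens the station list via ''.join(data.station) and sorts the variables, so the same station/variable set always resolves to the same file regardless of the order the variables were passed in. A quick illustration with made-up values:

>>> station, variables = ["DEBW107"], ["temp", "o3"]
>>> f"{''.join(station)}_{'_'.join(sorted(variables))}.pickle"
'DEBW107_o3_temp.pickle'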
def get_station_key(self, key: Union[None, str, int, List[Union[None, str, int]]]) -> str:
......
@@ -36,7 +36,7 @@ class PreProcessing(RunEnvironment):
def _run(self):
args = self.data_store.create_args_dict(DEFAULT_ARGS_LIST, scope="general.preprocessing")
kwargs = self.data_store.create_args_dict(DEFAULT_KWARGS_LIST, scope="general.preprocessing")
valid_stations = self.check_valid_stations(args, kwargs, self.data_store.get("stations", "general"))
valid_stations = self.check_valid_stations(args, kwargs, self.data_store.get("stations", "general"), load_tmp=False)
self.data_store.set("stations", valid_stations, "general")
self.split_train_val_test()
self.report_pre_processing()
@@ -97,7 +97,7 @@
self.data_store.set("generator", data_set, scope)
@staticmethod
def check_valid_stations(args: Dict, kwargs: Dict, all_stations: List[str]):
def check_valid_stations(args: Dict, kwargs: Dict, all_stations: List[str], load_tmp: bool = True):
"""
Check if all given stations in `all_stations` are valid, i.e. data is available for the given time range
(included in `kwargs`). The shape and the loading time are logged in debug mode.
@@ -118,9 +118,10 @@
for station in all_stations:
t_inner.run()
try:
(history, label) = data_gen[station]
data = data_gen.get_data_generator(key=station, load_tmp=load_tmp)
valid_stations.append(station)
logging.debug(f"{station}: history_shape = {history.shape}")
logging.debug(f'{station}: history_shape = {data.history.transpose("datetime", "window", "Stations", "variables").shape}')
logging.debug(f"{station}: loading time = {t_inner}")
except (AttributeError, EmptyQueryResult):
continue
......
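The load_tmp=False passed from _run acts as a cache warm-up: during the validity check every station is prepared from source exactly once and pickled via _save_pickle_data, so later get_data_generator calls (e.g. while building the train/val/test generators) are served from the cache. In terms of the hypothetical load_or_build sketch above:

# preprocessing: bypass stale pickles, rebuild and rewrite the cache once
data = load_or_build(cache_dir, station, variables, build_data, use_cache=False)

# training and later lookups: read back from the pickle cache
data = load_or_build(cache_dir, station, variables, build_data, use_cache=True)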