Commit 47c9e0d3 authored by lukas leufen

do not calculate transformation if mean is given; add a switch to control whether local tmp data is saved

parent 8d7ceb65
2 merge requests: !50 "release for v0.7.0", !49 "Lukas issue054 feat transformation on entire dataset"
Pipeline #30946 passed
@@ -99,10 +99,12 @@ class DataGenerator(keras.utils.Sequence):
     def setup_transformation(self, transformation):
         if transformation is None:
             return
+        transformation = transformation.copy()
         scope = transformation.get("scope", "station")
         method = transformation.get("method", "standardise")
         mean = transformation.get("mean", None)
         std = transformation.get("std", None)
+        if isinstance(mean, str):
             if scope == "data":
                 if mean == "accurate":
                     mean, std = self.calculate_accurate_transformation(method)
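The two added lines carry the first half of the commit message: the settings dict is copied so the caller's transformation config is never mutated, and statistics are only computed when "mean" is a string keyword such as "accurate"; a mean that is already given numerically is reused as-is. A minimal standalone sketch of that dispatch, where calculate_accurate stands in for the generator's calculate_accurate_transformation method and the explicit-numeric branch is inferred from the commit message rather than copied from the code:

    # Sketch only, not the project's code: compute statistics just when the
    # "mean" entry is a keyword; reuse explicitly given numeric values.
    def resolve_mean_std(transformation, calculate_accurate):
        transformation = transformation.copy()  # never mutate the caller's dict
        method = transformation.get("method", "standardise")
        mean = transformation.get("mean", None)
        std = transformation.get("std", None)
        if isinstance(mean, str):  # e.g. "accurate" -> expensive computation
            mean, std = calculate_accurate(method)
        return mean, std

    # Given mean/std pass straight through; nothing is calculated:
    assert resolve_mean_std({"mean": 42.0, "std": 3.0}, None) == (42.0, 3.0)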
@@ -150,18 +152,20 @@ class DataGenerator(keras.utils.Sequence):
                 continue
         return mean.mean("Stations") if mean.shape[1] > 0 else None, std.mean("Stations") if std.shape[1] > 0 else None

-    def get_data_generator(self, key: Union[str, int] = None, local_tmp_storage: bool = True) -> DataPrep:
+    def get_data_generator(self, key: Union[str, int] = None, load_local_tmp_storage: bool = True,
+                           save_local_tmp_storage: bool = True) -> DataPrep:
         """
         Select data for given key, create a DataPrep object and interpolate, transform, make history and labels and
         remove nans.
         :param key: station key to choose the data generator.
-        :param local_tmp_storage: say if data should be processed from scratch or loaded as already processed data from
-            tmp pickle file to save computational time (but of course more disk space required).
+        :param load_local_tmp_storage: say if data should be processed from scratch or loaded as already processed data
+            from tmp pickle file to save computational time (but of course more disk space required).
+        :param save_local_tmp_storage: save processed data as temporary file locally (default True)
         :return: preprocessed data as a DataPrep instance
         """
         station = self.get_station_key(key)
         try:
-            if not local_tmp_storage:
+            if not load_local_tmp_storage:
                 raise FileNotFoundError
             data = self._load_pickle_data(station, self.variables)
         except FileNotFoundError:
@@ -169,10 +173,12 @@ class DataGenerator(keras.utils.Sequence):
             data = DataPrep(self.data_path, self.network, station, self.variables, station_type=self.station_type,
                             **self.kwargs)
             data.interpolate(self.interpolate_dim, method=self.interpolate_method, limit=self.limit_nan_fill)
             if self.transformation is not None:
                 data.transform("datetime", **helpers.dict_pop(self.transformation, "scope"))
             data.make_history_window(self.interpolate_dim, self.window_history_size)
             data.make_labels(self.target_dim, self.target_var, self.interpolate_dim, self.window_lead_time)
             data.history_label_nan_remove(self.interpolate_dim)
+            if save_local_tmp_storage:
+                self._save_pickle_data(data)
         return data
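Splitting the old local_tmp_storage flag into separate load and save variants makes the cache behaviour fully orthogonal: a caller can force a fresh computation and still refresh the pickle cache, or, as the pre-processing run further down does, recompute without leaving a temporary file behind. A hypothetical call site (the generator gen and the station id are assumed for illustration):

    # Hypothetical usage; gen is a DataGenerator built elsewhere.
    # Recompute from scratch but refresh the local pickle cache:
    data = gen.get_data_generator("DEBW107", load_local_tmp_storage=False,
                                  save_local_tmp_storage=True)
    # Recompute and leave no pickle behind (used by the validity check,
    # where caching every candidate station would just cost disk space):
    data = gen.get_data_generator("DEBW107", load_local_tmp_storage=False,
                                  save_local_tmp_storage=False)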
@@ -79,6 +79,7 @@ class ExperimentSetup(RunEnvironment):
         self._set_param("sampling", sampling)
         self._set_param("transformation", transformation, default={"scope": "data", "method": "standardise",
                                                                    "mean": "estimate"})
+        self._set_param("transformation", None, scope="general.preprocessing")

         # target
         self._set_param("target_var", target_var, default="o3")
@@ -35,7 +35,7 @@ class PreProcessing(RunEnvironment):
     def _run(self):
         args = self.data_store.create_args_dict(DEFAULT_ARGS_LIST, scope="general.preprocessing")
         kwargs = self.data_store.create_args_dict(DEFAULT_KWARGS_LIST, scope="general.preprocessing")
-        valid_stations = self.check_valid_stations(args, kwargs, self.data_store.get("stations", "general"), load_tmp=False)
+        valid_stations = self.check_valid_stations(args, kwargs, self.data_store.get("stations", "general"), load_tmp=False, save_tmp=False)
         self.calculate_transformation(args, kwargs, valid_stations, load_tmp=False)
         self.data_store.set("stations", valid_stations, "general")
         self.split_train_val_test()
@@ -94,14 +94,16 @@ class PreProcessing(RunEnvironment):
             else:
                 set_stations = stations[index_list]
             logging.debug(f"{set_name.capitalize()} stations (len={len(set_stations)}): {set_stations}")
-            set_stations = self.check_valid_stations(args, kwargs, set_stations)
+            set_stations = self.check_valid_stations(args, kwargs, set_stations, load_tmp=False)
             self.data_store.set("stations", set_stations, scope)
             set_args = self.data_store.create_args_dict(DEFAULT_ARGS_LIST, scope)
             data_set = DataGenerator(**set_args, **kwargs)
             self.data_store.set("generator", data_set, scope)
+            if set_name == "train":
+                self.data_store.set("transformation", data_set.transformation, "general")

     @staticmethod
-    def check_valid_stations(args: Dict, kwargs: Dict, all_stations: List[str], load_tmp=True):
+    def check_valid_stations(args: Dict, kwargs: Dict, all_stations: List[str], load_tmp=True, save_tmp=True):
         """
         Check if all given stations in `all_stations` are valid. Valid means that there is data available for the given
         time range (included in `kwargs`). The shape and the loading time are logged in debug mode.
@@ -123,7 +125,8 @@ class PreProcessing(RunEnvironment):
             t_inner.run()
             try:
                 # (history, label) = data_gen[station]
-                data = data_gen.get_data_generator(key=station, local_tmp_storage=load_tmp)
+                data = data_gen.get_data_generator(key=station, load_local_tmp_storage=load_tmp,
+                                                   save_local_tmp_storage=save_tmp)
                 valid_stations.append(station)
                 logging.debug(f'{station}: history_shape = {data.history.transpose("datetime", "window", "Stations", "variables").shape}')
                 logging.debug(f"{station}: loading time = {t_inner}")
@@ -104,7 +104,7 @@ class TestDataGenerator:
         if os.path.exists(file):
             os.remove(file)
         assert not os.path.exists(file)
-        assert isinstance(gen.get_data_generator("DEBW107", local_tmp_storage=False), DataPrep)
+        assert isinstance(gen.get_data_generator("DEBW107", load_local_tmp_storage=False), DataPrep)
         t = os.stat(file).st_ctime
         assert os.path.exists(file)
         assert isinstance(gen.get_data_generator("DEBW107"), DataPrep)