Skip to content
Snippets Groups Projects
Commit 63324afa authored by lukas leufen's avatar lukas leufen
Browse files

temporary state, transformation is calculated on each Init

parent 12aa46e6
Branches
Tags
2 merge requests!50release for v0.7.0,!49Lukas issue054 feat transformation on entire dataset
Pipeline #29592 passed
......@@ -2,7 +2,7 @@ __author__ = 'Felix Kleinert, Lukas Leufen'
__date__ = '2019-11-07'
import os
from typing import Union, List, Tuple, Any
from typing import Union, List, Tuple, Any, Dict
import keras
import xarray as xr
......@@ -11,6 +11,7 @@ import logging
from src import helpers
from src.data_handling.data_preparation import DataPrep
from src.join import EmptyQueryResult
class DataGenerator(keras.utils.Sequence):
......@@ -25,7 +26,7 @@ class DataGenerator(keras.utils.Sequence):
def __init__(self, data_path: str, network: str, stations: Union[str, List[str]], variables: List[str],
interpolate_dim: str, target_dim: str, target_var: str, station_type: str = None,
interpolate_method: str = "linear", limit_nan_fill: int = 1, window_history_size: int = 7,
window_lead_time: int = 4, transform_method: str = "standardise", **kwargs):
window_lead_time: int = 4, transformation: Dict = None, **kwargs):
self.data_path = os.path.abspath(data_path)
self.data_path_tmp = os.path.join(os.path.abspath(data_path), "tmp")
if not os.path.exists(self.data_path_tmp):
......@@ -41,8 +42,8 @@ class DataGenerator(keras.utils.Sequence):
self.limit_nan_fill = limit_nan_fill
self.window_history_size = window_history_size
self.window_lead_time = window_lead_time
self.transform_method = transform_method
self.kwargs = kwargs
self.transformation = self.setup_transformation(transformation)
def __repr__(self):
"""
......@@ -94,6 +95,44 @@ class DataGenerator(keras.utils.Sequence):
data = self.get_data_generator(key=item)
return data.get_transposed_history(), data.label.squeeze("Stations").transpose("datetime", "window")
def setup_transformation(self, transformation: Dict) -> Union[Dict, None]:
    """
    Normalise the transformation settings dictionary.

    Fills in the defaults (scope="station", method="standardise") and, when the
    transformation is to be computed over the whole data set (scope="data"),
    resolves the special mean markers "accurate" / "estimate" into concrete
    mean/std values via the corresponding calculation methods. The given dict
    is updated in place and returned.

    :param transformation: settings dict with optional keys "scope", "method",
        "mean" and "std"; may be None to disable transformation entirely.
    :return: the updated settings dict, or None if no transformation was given.
    """
    if transformation is None:
        return
    scope = transformation.get("scope", "station")
    method = transformation.get("method", "standardise")
    mean = transformation.get("mean", None)
    std = transformation.get("std", None)
    if scope == "data":
        if mean == "accurate":
            mean, std = self.calculate_accurate_transformation(method)
        elif mean == "estimate":
            mean, std = self.calculate_estimated_transformation(method)
        # any other value is taken as an already computed mean (fix: dropped
        # the redundant no-op assignment `mean = mean` from the original)
    transformation["mean"] = mean
    transformation["std"] = std
    return transformation
def calculate_accurate_transformation(self, method):
    """
    Placeholder for the exact (entire-data-set) transformation statistics.

    Not implemented yet; always returns (None, None).

    :param method: transformation method name (currently ignored)
    :return: tuple of (mean, std), both None
    """
    return None, None
def calculate_estimated_transformation(self, method):
    """
    Estimate data-set-wide transformation statistics from per-station values.

    For each station a DataPrep object is created, its per-station mean/std are
    computed by applying the given transformation method, and the station is
    immediately transformed back (inverse) so the data itself stays untouched.
    Stations that yield no data (EmptyQueryResult) are skipped. The final
    estimate is the average over all stations' statistics.

    :param method: transformation method passed to DataPrep.transform
    :return: tuple (mean, std) as DataArrays averaged over the "Stations"
        dimension, or (None, None) if no station delivered any statistics.
    """
    dims = ["variables", "Stations"]
    coords = {"variables": self.variables, "Stations": range(0)}
    mean = xr.DataArray([[]] * len(self.variables), coords=coords, dims=dims)
    std = xr.DataArray([[]] * len(self.variables), coords=coords, dims=dims)
    for station in self.stations:
        try:
            data = DataPrep(self.data_path, self.network, station, self.variables,
                            station_type=self.station_type, **self.kwargs)
            data.transform("datetime", method=method)
            mean = mean.combine_first(data.mean)
            std = std.combine_first(data.std)
            # undo the transformation so the station data is left unchanged
            data.transform("datetime", method=method, inverse=True)
        except EmptyQueryResult:
            continue
    # fix: return None instead of the leftover debug placeholder "hi" when the
    # loop collected no statistics at all
    if mean.shape[1] > 0:
        return mean.mean("Stations"), std.mean("Stations")
    return None, None
def get_data_generator(self, key: Union[str, int] = None, local_tmp_storage: bool = True) -> DataPrep:
"""
Select data for given key, create a DataPrep object and interpolate, transform, make history and labels and
......@@ -113,7 +152,7 @@ class DataGenerator(keras.utils.Sequence):
data = DataPrep(self.data_path, self.network, station, self.variables, station_type=self.station_type,
**self.kwargs)
data.interpolate(self.interpolate_dim, method=self.interpolate_method, limit=self.limit_nan_fill)
data.transform("datetime", method=self.transform_method)
data.transform("datetime", **helpers.dict_pop(self.transformation, "scope"))
data.make_history_window(self.interpolate_dim, self.window_history_size)
data.make_labels(self.target_dim, self.target_var, self.interpolate_dim, self.window_lead_time)
data.history_label_nan_remove(self.interpolate_dim)
......
......@@ -190,3 +190,8 @@ def float_round(number: float, decimals: int = 0, round_type: Callable = math.ce
"""
multiplier = 10. ** decimals
return round_type(number * multiplier) / multiplier
def dict_pop(dict: Dict, pop_keys):
    """
    Return a shallow copy of *dict* without the given key(s).

    :param dict: the source dictionary (left unmodified)
    :param pop_keys: a single key or a list of keys to drop
    :return: a new dict containing all entries whose key is not in pop_keys
    """
    pop_keys = to_list(pop_keys)
    remaining = {}
    for key, value in dict.items():
        if key not in pop_keys:
            remaining[key] = value
    return remaining
......@@ -33,7 +33,7 @@ class ExperimentSetup(RunEnvironment):
limit_nan_fill=None, train_start=None, train_end=None, val_start=None, val_end=None, test_start=None,
test_end=None, use_all_stations_on_all_data_sets=True, trainable=None, fraction_of_train=None,
experiment_path=None, plot_path=None, forecast_path=None, overwrite_local_data=None, sampling="daily",
create_new_model=None):
create_new_model=None, transformation=None):
# create run framework
super().__init__()
......@@ -77,6 +77,8 @@ class ExperimentSetup(RunEnvironment):
self._set_param("window_history_size", window_history_size, default=13)
self._set_param("overwrite_local_data", overwrite_local_data, default=False, scope="general.preprocessing")
self._set_param("sampling", sampling)
self._set_param("transformation", transformation, default={"scope": "data", "method": "standardise",
"mean": "estimate"})
# target
self._set_param("target_var", target_var, default="o3")
......
......@@ -12,7 +12,7 @@ from src.run_modules.run_environment import RunEnvironment
DEFAULT_ARGS_LIST = ["data_path", "network", "stations", "variables", "interpolate_dim", "target_dim", "target_var"]
DEFAULT_KWARGS_LIST = ["limit_nan_fill", "window_history_size", "window_lead_time", "statistics_per_var",
"station_type", "overwrite_local_data", "start", "end", "sampling"]
"station_type", "overwrite_local_data", "start", "end", "sampling", "transformation"]
class PreProcessing(RunEnvironment):
......@@ -36,10 +36,15 @@ class PreProcessing(RunEnvironment):
args = self.data_store.create_args_dict(DEFAULT_ARGS_LIST, scope="general.preprocessing")
kwargs = self.data_store.create_args_dict(DEFAULT_KWARGS_LIST, scope="general.preprocessing")
valid_stations = self.check_valid_stations(args, kwargs, self.data_store.get("stations", "general"), load_tmp=False)
self.calculate_transformation(args, kwargs, valid_stations, load_tmp=False)
self.data_store.set("stations", valid_stations, "general")
self.split_train_val_test()
self.report_pre_processing()
def calculate_transformation(self, args: Dict, kwargs: Dict, all_stations: List[str], load_tmp):
    """
    Placeholder: intended to pre-compute the transformation statistics for all
    valid stations during pre-processing. Currently a no-op — the actual
    calculation still happens inside DataGenerator on each init (see commit
    message: "temporary state").

    :param args: default arguments dict from the data store
    :param kwargs: default keyword arguments dict from the data store
    :param all_stations: list of valid station identifiers
    :param load_tmp: whether to load from local tmp storage (currently unused)
    """
    pass
def report_pre_processing(self):
logging.debug(20 * '##')
n_train = len(self.data_store.get('generator', 'general.train'))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment