Skip to content
Snippets Groups Projects
Commit 3c05a5bd authored by lukas leufen's avatar lukas leufen
Browse files

added min_length parameter to data handling, currently only supported for test set

parent b146b12b
No related branches found
No related tags found
3 merge requests!90WIP: new release update,!89Resolve "release branch / CI on gpu",!68Resolve "Minimal sample size in train-, val, test-set"
Pipeline #31754 passed
......@@ -353,7 +353,9 @@ class DataPrep(object):
non_nan_observation = self.observation.dropna(dim=dim)
intersect = reduce(np.intersect1d, (non_nan_history.coords[dim].values, non_nan_label.coords[dim].values, non_nan_observation.coords[dim].values))
if len(intersect) == 0:
min_length = self.kwargs.get("min_length", 0)
length = len(intersect)
if len(intersect) < max(min_length, 1):
self.history = None
self.label = None
self.observation = None
......
......@@ -34,7 +34,8 @@ class ExperimentSetup(RunEnvironment):
limit_nan_fill=None, train_start=None, train_end=None, val_start=None, val_end=None, test_start=None,
test_end=None, use_all_stations_on_all_data_sets=True, trainable=None, fraction_of_train=None,
experiment_path=None, plot_path=None, forecast_path=None, overwrite_local_data=None, sampling="daily",
create_new_model=None, bootstrap_path=None, permute_data_on_training=None, transformation=None):
create_new_model=None, bootstrap_path=None, permute_data_on_training=None, transformation=None,
test_min_length=None):
# create run framework
super().__init__()
......@@ -107,6 +108,7 @@ class ExperimentSetup(RunEnvironment):
# test set parameters
self._set_param("start", test_start, default="2010-01-01", scope="general.test")
self._set_param("end", test_end, default="2017-12-31", scope="general.test")
self._set_param("min_length", test_min_length, default=30, scope="general.test")
# train_val set parameters
self._set_param("start", self.data_store.get("start", "general.train"), scope="general.train_val")
......
......@@ -11,7 +11,7 @@ from src.join import EmptyQueryResult
from src.run_modules.run_environment import RunEnvironment
DEFAULT_ARGS_LIST = ["data_path", "network", "stations", "variables", "interpolate_dim", "target_dim", "target_var"]
DEFAULT_KWARGS_LIST = ["limit_nan_fill", "window_history_size", "window_lead_time", "statistics_per_var",
DEFAULT_KWARGS_LIST = ["limit_nan_fill", "window_history_size", "window_lead_time", "statistics_per_var", "min_length",
"station_type", "overwrite_local_data", "start", "end", "sampling", "transformation"]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment