Skip to content
Snippets Groups Projects
Commit ea62d11c authored by lukas leufen's avatar lukas leufen
Browse files

added permutation parameter to experiment setup

parents 34277ea7 022d9a73
No related branches found
No related tags found
2 merge requests!50release for v0.7.0,!44Felix issue057 permutate data for minibatches
Pipeline #30914 passed
......@@ -12,11 +12,10 @@ import numpy as np
class Distributor(keras.utils.Sequence):
def __init__(self, generator: keras.utils.Sequence, model: keras.models, batch_size: int = 256,
fit_call: bool = True, permute_data: bool = False):
permute_data: bool = False):
self.generator = generator
self.model = model
self.batch_size = batch_size
self.fit_call = fit_call
self.do_data_permutation = permute_data
def _get_model_rank(self):
......
......@@ -33,7 +33,7 @@ class ExperimentSetup(RunEnvironment):
limit_nan_fill=None, train_start=None, train_end=None, val_start=None, val_end=None, test_start=None,
test_end=None, use_all_stations_on_all_data_sets=True, trainable=None, fraction_of_train=None,
experiment_path=None, plot_path=None, forecast_path=None, overwrite_local_data=None, sampling="daily",
create_new_model=None):
create_new_model=None, permute_data_on_training=None):
# create run framework
super().__init__()
......@@ -45,6 +45,7 @@ class ExperimentSetup(RunEnvironment):
trainable = True
self._set_param("trainable", trainable, default=True)
self._set_param("fraction_of_training", fraction_of_train, default=0.8)
self._set_param("permute_data", permute_data_on_training, default=False, scope="general.train")
# set experiment name
exp_date = self._get_parser_args(parser_args).get("experiment_date")
......
......@@ -65,7 +65,8 @@ class Training(RunEnvironment):
:param mode: name of set, should be from ["train", "val", "test"]
"""
gen = self.data_store.get("generator", f"general.{mode}")
setattr(self, f"{mode}_set", Distributor(gen, self.model, self.batch_size))
permute_data = self.data_store.get_default("permute_data", f"general.{mode}", default=False)
setattr(self, f"{mode}_set", Distributor(gen, self.model, self.batch_size, permute_data=permute_data))
def set_generators(self) -> None:
"""
......
......@@ -37,7 +37,6 @@ class TestDistributor:
def test_init_defaults(self, distributor):
assert distributor.batch_size == 256
assert distributor.fit_call is True
assert distributor.do_data_permutation is False
def test_get_model_rank(self, distributor, model_with_minor_branch):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment