Skip to content
Snippets Groups Projects
Commit ea62d11c authored by lukas leufen's avatar lukas leufen
Browse files

added permutation parameter to experiment setup

parents 34277ea7 022d9a73
Branches
Tags
2 merge requests!50release for v0.7.0,!44Felix issue057 permutate data for minibatches
Pipeline #30914 passed
......@@ -12,11 +12,10 @@ import numpy as np
class Distributor(keras.utils.Sequence):
def __init__(self, generator: keras.utils.Sequence, model: keras.models, batch_size: int = 256,
fit_call: bool = True, permute_data: bool = False):
permute_data: bool = False):
self.generator = generator
self.model = model
self.batch_size = batch_size
self.fit_call = fit_call
self.do_data_permutation = permute_data
def _get_model_rank(self):
......
......@@ -33,7 +33,7 @@ class ExperimentSetup(RunEnvironment):
limit_nan_fill=None, train_start=None, train_end=None, val_start=None, val_end=None, test_start=None,
test_end=None, use_all_stations_on_all_data_sets=True, trainable=None, fraction_of_train=None,
experiment_path=None, plot_path=None, forecast_path=None, overwrite_local_data=None, sampling="daily",
create_new_model=None):
create_new_model=None, permute_data_on_training=None):
# create run framework
super().__init__()
......@@ -45,6 +45,7 @@ class ExperimentSetup(RunEnvironment):
trainable = True
self._set_param("trainable", trainable, default=True)
self._set_param("fraction_of_training", fraction_of_train, default=0.8)
self._set_param("permute_data", permute_data_on_training, default=False, scope="general.train")
# set experiment name
exp_date = self._get_parser_args(parser_args).get("experiment_date")
......
......@@ -65,7 +65,8 @@ class Training(RunEnvironment):
:param mode: name of set, should be from ["train", "val", "test"]
"""
gen = self.data_store.get("generator", f"general.{mode}")
setattr(self, f"{mode}_set", Distributor(gen, self.model, self.batch_size))
permute_data = self.data_store.get_default("permute_data", f"general.{mode}", default=False)
setattr(self, f"{mode}_set", Distributor(gen, self.model, self.batch_size, permute_data=permute_data))
def set_generators(self) -> None:
"""
......
......@@ -37,7 +37,6 @@ class TestDistributor:
def test_init_defaults(self, distributor):
assert distributor.batch_size == 256
assert distributor.fit_call is True
assert distributor.do_data_permutation is False
def test_get_model_rank(self, distributor, model_with_minor_branch):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment