diff --git a/src/data_handling/data_distributor.py b/src/data_handling/data_distributor.py index 8a872997877536f948483b66c90db30c1c849f3d..b1624410e746ab779b20a60d6a7d19b4ae3b1267 100644 --- a/src/data_handling/data_distributor.py +++ b/src/data_handling/data_distributor.py @@ -12,11 +12,10 @@ import numpy as np class Distributor(keras.utils.Sequence): def __init__(self, generator: keras.utils.Sequence, model: keras.models, batch_size: int = 256, - fit_call: bool = True, permute_data: bool = False): + permute_data: bool = False): self.generator = generator self.model = model self.batch_size = batch_size - self.fit_call = fit_call self.do_data_permutation = permute_data def _get_model_rank(self): diff --git a/src/run_modules/experiment_setup.py b/src/run_modules/experiment_setup.py index 48f7c13e51622d7d52405b73c0a6f57537b5b476..44a5272ed99b8072168ac7cced7047feda8c2487 100644 --- a/src/run_modules/experiment_setup.py +++ b/src/run_modules/experiment_setup.py @@ -33,7 +33,7 @@ class ExperimentSetup(RunEnvironment): limit_nan_fill=None, train_start=None, train_end=None, val_start=None, val_end=None, test_start=None, test_end=None, use_all_stations_on_all_data_sets=True, trainable=None, fraction_of_train=None, experiment_path=None, plot_path=None, forecast_path=None, overwrite_local_data=None, sampling="daily", - create_new_model=None): + create_new_model=None, permute_data_on_training=None): # create run framework super().__init__() @@ -45,6 +45,7 @@ class ExperimentSetup(RunEnvironment): trainable = True self._set_param("trainable", trainable, default=True) self._set_param("fraction_of_training", fraction_of_train, default=0.8) + self._set_param("permute_data", permute_data_on_training, default=False, scope="general.train") # set experiment name exp_date = self._get_parser_args(parser_args).get("experiment_date") diff --git a/src/run_modules/training.py b/src/run_modules/training.py index 7a522af0298bcabee62579f68bd29ed123cac7b0..df60c4f2f8dff4a9acb82920ad3c1d203813033d 100644 --- a/src/run_modules/training.py +++ b/src/run_modules/training.py @@ -65,7 +65,8 @@ class Training(RunEnvironment): :param mode: name of set, should be from ["train", "val", "test"] """ gen = self.data_store.get("generator", f"general.{mode}") - setattr(self, f"{mode}_set", Distributor(gen, self.model, self.batch_size)) + permute_data = self.data_store.get_default("permute_data", f"general.{mode}", default=False) + setattr(self, f"{mode}_set", Distributor(gen, self.model, self.batch_size, permute_data=permute_data)) def set_generators(self) -> None: """ diff --git a/test/test_data_handling/test_data_distributor.py b/test/test_data_handling/test_data_distributor.py index 109a233ebe4d354bc03359cf5acec81d0f8ebac0..a26e76a0e7f3ef0f5cdbedc07d73a690116966c9 100644 --- a/test/test_data_handling/test_data_distributor.py +++ b/test/test_data_handling/test_data_distributor.py @@ -37,7 +37,6 @@ class TestDistributor: def test_init_defaults(self, distributor): assert distributor.batch_size == 256 - assert distributor.fit_call is True assert distributor.do_data_permutation is False def test_get_model_rank(self, distributor, model_with_minor_branch):