From 022d9a73622194e6d25d0f83932e3d3c140ba71c Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Mon, 2 Mar 2020 14:59:14 +0100
Subject: [PATCH] permute data on training can be set in experiment setup

---
 src/data_handling/data_distributor.py            | 3 +--
 src/run_modules/experiment_setup.py              | 3 ++-
 src/run_modules/training.py                      | 3 ++-
 test/test_data_handling/test_data_distributor.py | 1 -
 4 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/data_handling/data_distributor.py b/src/data_handling/data_distributor.py
index 8a872997..b1624410 100644
--- a/src/data_handling/data_distributor.py
+++ b/src/data_handling/data_distributor.py
@@ -12,11 +12,10 @@ import numpy as np
 class Distributor(keras.utils.Sequence):
 
     def __init__(self, generator: keras.utils.Sequence, model: keras.models, batch_size: int = 256,
-                 fit_call: bool = True, permute_data: bool = False):
+                 permute_data: bool = False):
         self.generator = generator
         self.model = model
         self.batch_size = batch_size
-        self.fit_call = fit_call
         self.do_data_permutation = permute_data
 
     def _get_model_rank(self):
diff --git a/src/run_modules/experiment_setup.py b/src/run_modules/experiment_setup.py
index 48f7c13e..44a5272e 100644
--- a/src/run_modules/experiment_setup.py
+++ b/src/run_modules/experiment_setup.py
@@ -33,7 +33,7 @@ class ExperimentSetup(RunEnvironment):
                  limit_nan_fill=None, train_start=None, train_end=None, val_start=None, val_end=None, test_start=None,
                  test_end=None, use_all_stations_on_all_data_sets=True, trainable=None, fraction_of_train=None,
                  experiment_path=None, plot_path=None, forecast_path=None, overwrite_local_data=None, sampling="daily",
-                 create_new_model=None):
+                 create_new_model=None, permute_data_on_training=None):
 
         # create run framework
         super().__init__()
@@ -45,6 +45,7 @@ class ExperimentSetup(RunEnvironment):
             trainable = True
         self._set_param("trainable", trainable, default=True)
         self._set_param("fraction_of_training", fraction_of_train, default=0.8)
+        self._set_param("permute_data", permute_data_on_training, default=False, scope="general.train")
 
         # set experiment name
         exp_date = self._get_parser_args(parser_args).get("experiment_date")
diff --git a/src/run_modules/training.py b/src/run_modules/training.py
index 7a522af0..df60c4f2 100644
--- a/src/run_modules/training.py
+++ b/src/run_modules/training.py
@@ -65,7 +65,8 @@ class Training(RunEnvironment):
         :param mode: name of set, should be from ["train", "val", "test"]
         """
         gen = self.data_store.get("generator", f"general.{mode}")
-        setattr(self, f"{mode}_set", Distributor(gen, self.model, self.batch_size))
+        permute_data = self.data_store.get_default("permute_data", f"general.{mode}", default=False)
+        setattr(self, f"{mode}_set", Distributor(gen, self.model, self.batch_size, permute_data=permute_data))
 
     def set_generators(self) -> None:
         """
diff --git a/test/test_data_handling/test_data_distributor.py b/test/test_data_handling/test_data_distributor.py
index 109a233e..a26e76a0 100644
--- a/test/test_data_handling/test_data_distributor.py
+++ b/test/test_data_handling/test_data_distributor.py
@@ -37,7 +37,6 @@ class TestDistributor:
 
     def test_init_defaults(self, distributor):
         assert distributor.batch_size == 256
-        assert distributor.fit_call is True
         assert distributor.do_data_permutation is False
 
     def test_get_model_rank(self, distributor, model_with_minor_branch):
-- 
GitLab