diff --git a/run.py b/run.py
index 556cd0b59ed023178fa7e6df1b5b03b9f6c5f157..1d17b4c613850464f52745811292cba58ce5de30 100644
--- a/run.py
+++ b/run.py
@@ -17,7 +17,7 @@ def main(parser_args):
 
     with RunEnvironment():
         ExperimentSetup(parser_args, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBW001'],
-                        station_type='background', trainable=False, create_new_model=False)
+                        station_type='background', trainable=False, create_new_model=True)
         PreProcessing()
 
         ModelSetup()
diff --git a/src/data_handling/bootstraps.py b/src/data_handling/bootstraps.py
index 6b66b87552e4962ec921e6c37d509c90121f3b9b..05f4e6c2831c43ffc09586cc025c1fb457003d82 100644
--- a/src/data_handling/bootstraps.py
+++ b/src/data_handling/bootstraps.py
@@ -48,6 +48,76 @@ class BootStrapGenerator:
                         yield boot_hist, label
             return
 
+    def get_generator_station_wise(self, station):
+        """
+        This is the implementation of the __next__ method of the iterator protocol. Get the data generator, and return
+        the history and label data of this generator.
+        :return:
+        """
+        hist, label = self.orig_generator[station]
+        shuffled_data = self.load_shuffled_data(station, self.variables)
+
+        def f():
+            while True:
+                for var in self.variables:
+                    logging.debug(f"  var: {var}")
+                    for boot in range(self.number_of_boots):
+                        logging.debug(f"boot: {boot}")
+                        boot_hist = hist.sel(variables=helpers.list_pop(self.variables, var))
+                        shuffled_var = shuffled_data.sel(variables=var, boots=boot).expand_dims("variables").drop("boots").transpose("datetime", "window", "Stations", "variables")
+                        boot_hist = boot_hist.combine_first(shuffled_var)
+                        boot_hist = boot_hist.sortby("variables")
+                        yield boot_hist
+                return
+
+        return hist, label, f, self.number_of_boots * len(self.variables)
+
+    def get_bootstrap_meta_station_wise(self, station) -> List:
+        """
+        Create meta data on the ordering of the variable bootstraps, following the ordering of
+        get_generator_station_wise.
+        :return: list of pairs with the bootstrapped variable first and its corresponding station second.
+        """
+        bootstrap_meta = []
+        label = self.orig_generator.get_data_generator(station).get_transposed_label()
+        for var in self.variables:
+            for boot in range(self.number_of_boots):
+                bootstrap_meta.extend([[var, station]] * len(label))
+        return bootstrap_meta
+
+    def get_generator_station_var_wise(self, station, var):
+        """
+        Create a bootstrap generator for a single station and a single variable. For each bootstrap realisation, the
+        shuffled data of the given variable replaces its original history while all other variables stay untouched.
+        :return: history, label, a generator function yielding the bootstrapped history, and the number of samples it
+            yields (number of boots).
+        """
+        hist, label = self.orig_generator[station]
+        shuffled_data = self.load_shuffled_data(station, self.variables)
+
+        def f():
+            while True:
+                for boot in range(self.number_of_boots):
+                    logging.debug(f"boot: {boot}")
+                    boot_hist = hist.sel(variables=helpers.list_pop(self.variables, var))
+                    shuffled_var = shuffled_data.sel(variables=var, boots=boot).expand_dims("variables").drop("boots").transpose("datetime", "window", "Stations", "variables")
+                    boot_hist = boot_hist.combine_first(shuffled_var)
+                    boot_hist = boot_hist.sortby("variables")
+                    yield boot_hist
+                return
+
+        return hist, label, f, self.number_of_boots
+
+    def get_bootstrap_meta_station_var_wise(self, station, var) -> List:
+        """
+        Create meta data on the ordering of the variable bootstraps, following the ordering of
+        get_generator_station_var_wise.
+        :return: list of pairs with the bootstrapped variable first and its corresponding station second.
+        """
+        bootstrap_meta = []
+        label = self.orig_generator.get_data_generator(station).get_transposed_label()
+        for boot in range(self.number_of_boots):
+            bootstrap_meta.extend([[var, station]] * len(label))
+        return bootstrap_meta
+
     def get_bootstrap_meta(self) -> List:
         """
         Create meta data on ordering of variable bootstraps according to ordering from get_generator method.
@@ -116,6 +186,26 @@ class BootStraps:
         self.create_shuffled_data()
         self._boot_strap_generator = BootStrapGenerator(self.data, self.number_bootstraps, self.bootstrap_path)
 
+    @property
+    def stations(self):
+        return self._boot_strap_generator.stations
+
+    @property
+    def variables(self):
+        return self._boot_strap_generator.variables
+
+    def get_generator_station_wise(self, station):
+        return self._boot_strap_generator.get_generator_station_wise(station)
+
+    def get_generator_station_var_wise(self, station, var):
+        return self._boot_strap_generator.get_generator_station_var_wise(station, var)
+
+    def get_bootstrap_meta_station_wise(self, station):
+        return self._boot_strap_generator.get_bootstrap_meta_station_wise(station)
+
+    def get_bootstrap_meta_station_var_wise(self, station, var):
+        return self._boot_strap_generator.get_bootstrap_meta_station_var_wise(station, var)
+
     def get_boot_strap_meta(self):
         return self._boot_strap_generator.get_bootstrap_meta()
 
diff --git a/src/run_modules/model_setup.py b/src/run_modules/model_setup.py
index 32ca0d2e82af32d8164d80ac42731e10f431a458..fe22e37f1f1cfdb48d8136fe53a6ce5ee7f4be97 100644
--- a/src/run_modules/model_setup.py
+++ b/src/run_modules/model_setup.py
@@ -10,8 +10,8 @@ import tensorflow as tf
 
 from src.model_modules.keras_extensions import HistoryAdvanced, CallbackHandler
 # from src.model_modules.model_class import MyBranchedModel as MyModel
-from src.model_modules.model_class import MyLittleModel as MyModel
-# from src.model_modules.model_class import MyTowerModel as MyModel
+# from src.model_modules.model_class import MyLittleModel as MyModel
+from src.model_modules.model_class import MyTowerModel as MyModel
 from src.run_modules.run_environment import RunEnvironment
 
 
diff --git a/src/run_modules/post_processing.py b/src/run_modules/post_processing.py
index f2d5a7d9b528beae954b2671effd6400a694f4aa..3389c659facf88383330ec0543760ce99b127aec 100644
--- a/src/run_modules/post_processing.py
+++ b/src/run_modules/post_processing.py
@@ -51,16 +51,28 @@ class PostProcessing(RunEnvironment):
             logging.info("take a look on the next reported time measure. If this increases a lot, one should think to "
                          "skip make_prediction() whenever it is possible to save time.")
 
-        # skill scores
-        self.skill_scores = self.calculate_skill_scores()
-
         # bootstraps
         if self.data_store.get("evaluate_bootstraps", "general.postprocessing"):
-            self.create_boot_straps()
-            self.bootstrap_skill_scores = self.calculate_bootstrap_skill_scores()
+            bootstrap_path = self.data_store.get("bootstrap_path", "general")
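+            # creating the BootStraps instance here already creates the shuffled data, presumably so that the
+            # timed comparisons below do not include this setup step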
+            BootStraps(self.test_data, bootstrap_path, 20)
+            with TimeTracking(name="split (refac_1)"):
+                self.create_boot_straps_refac_2()
+                self.bootstrap_skill_scores = self.calculate_bootstrap_skill_scores()
+            with TimeTracking(name="split (refac)"):
+                self.create_boot_straps_refac()
+                self.bootstrap_skill_scores = self.calculate_bootstrap_skill_scores()
+            with TimeTracking(name="merged"):
+                self.bootstrap_skill_scores = self.combined_boot_forecast_and_skill()
+            with TimeTracking(name="original version"):
+                self.create_boot_straps()
+                self.bootstrap_skill_scores = self.calculate_bootstrap_skill_scores()
+
+        # skill scores
+        # self.skill_scores = self.calculate_skill_scores()
 
         # plotting
-        self.plot()
+        # self.plot()
 
     def create_boot_straps(self):
         # forecast
@@ -70,6 +82,7 @@ class PostProcessing(RunEnvironment):
             window_lead_time = self.data_store.get("window_lead_time", "general")
             bootstraps = BootStraps(self.test_data, bootstrap_path, 20)
             # make bootstrap predictions
+            logging.info("predictions")
             bootstrap_predictions = self.model.predict_generator(generator=bootstraps.boot_strap_generator(),
                                                                  steps=bootstraps.get_boot_strap_generator_length(),
                                                                  use_multiprocessing=True)
@@ -81,6 +94,7 @@ class PostProcessing(RunEnvironment):
             # save bootstrap predictions separately for each station and variable combination
             variables = np.unique(bootstrap_meta[:, 0])
             for station in np.unique(bootstrap_meta[:, 1]):
+                logging.info(station)
                 coords = None
                 for boot in variables:
                     # store each variable - station - combination
@@ -97,6 +111,81 @@ class PostProcessing(RunEnvironment):
                 labels = xr.DataArray(labels, coords=(*coords, ["obs"]), dims=["index", "ahead", "type"])
                 labels.to_netcdf(file_name)
 
+    def create_boot_straps_refac(self):
+        # forecast
+        with TimeTracking(name="boot predictions"):
+            bootstrap_path = self.data_store.get("bootstrap_path", "general")
+            forecast_path = self.data_store.get("forecast_path", "general")
+            window_lead_time = self.data_store.get("window_lead_time", "general")
+            bootstraps = BootStraps(self.test_data, bootstrap_path, 20)
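+            # predict station by station: each station gets its own, much shorter bootstrap generator, and its
+            # predictions and meta data are written to disk before the next station is processed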
+            for station in bootstraps.stations:
+                with TimeTracking(name=station):
+                    logging.info(station)
+                    hist, label, station_bootstrap, length = bootstraps.get_generator_station_wise(station)
+
+                    # make bootstrap predictions
+                    bootstrap_predictions = self.model.predict_generator(generator=station_bootstrap(),
+                                                                         steps=length,
+                                                                         use_multiprocessing=True)
+                    if isinstance(bootstrap_predictions, list):
+                        bootstrap_predictions = bootstrap_predictions[-1]
+                    # get bootstrap prediction meta data
+                    bootstrap_meta = np.array(bootstraps.get_bootstrap_meta_station_wise(station))
+                    # save bootstrap predictions separately for each station and variable combination
+                    variables = np.unique(bootstrap_meta[:, 0])
+                    coords = None
+                    for boot in variables:
+                        # store each variable - station - combination
+                        ind = np.all(bootstrap_meta == [boot, station], axis=1)
+                        length = sum(ind)
+                        sel = bootstrap_predictions[ind].reshape((length, window_lead_time, 1))
+                        coords = (range(length), range(1, window_lead_time + 1))
+                        tmp = xr.DataArray(sel, coords=(*coords, [boot]), dims=["index", "ahead", "type"])
+                        file_name = os.path.join(forecast_path, f"bootstraps_{boot}_{station}.nc")
+                        tmp.to_netcdf(file_name)
+                    # store also true labels for each station
+                    labels = bootstraps.get_labels(station).reshape((length, window_lead_time, 1))
+                    file_name = os.path.join(forecast_path, f"bootstraps_labels_{station}.nc")
+                    labels = xr.DataArray(labels, coords=(*coords, ["obs"]), dims=["index", "ahead", "type"])
+                    labels.to_netcdf(file_name)
+
+    def create_boot_straps_refac_2(self):
+        # forecast
+        with TimeTracking(name="boot predictions"):
+            bootstrap_path = self.data_store.get("bootstrap_path", "general")
+            forecast_path = self.data_store.get("forecast_path", "general")
+            window_lead_time = self.data_store.get("window_lead_time", "general")
+            bootstraps = BootStraps(self.test_data, bootstrap_path, 20)
+
+            for station in bootstraps.stations:
+                with TimeTracking(name=station):
+                    logging.info(station)
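+                    # additionally split by variable: every call to predict_generator covers only the bootstraps
+                    # of a single variable at a single station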
+                    for var in bootstraps.variables:
+                        hist, label, station_bootstrap, length = bootstraps.get_generator_station_var_wise(station, var)
+
+                        # make bootstrap predictions
+                        bootstrap_predictions = self.model.predict_generator(generator=station_bootstrap(),
+                                                                             steps=length,
+                                                                             use_multiprocessing=True)
+                        if isinstance(bootstrap_predictions, list):
+                            bootstrap_predictions = bootstrap_predictions[-1]
+                        # get bootstrap prediction meta data
+                        bootstrap_meta = np.array(bootstraps.get_bootstrap_meta_station_var_wise(station, var))
+                        # save bootstrap predictions separately for each station and variable combination
+                        # store each variable - station - combination
+                        ind = np.all(bootstrap_meta == [var, station], axis=1)
+                        length = sum(ind)
+                        sel = bootstrap_predictions[ind].reshape((length, window_lead_time, 1))
+                        coords = (range(length), range(1, window_lead_time + 1))
+                        tmp = xr.DataArray(sel, coords=(*coords, [var]), dims=["index", "ahead", "type"])
+                        file_name = os.path.join(forecast_path, f"bootstraps_{var}_{station}.nc")
+                        tmp.to_netcdf(file_name)
+                    # store also true labels for each station
+                    labels = bootstraps.get_labels(station).reshape((length, window_lead_time, 1))
+                    file_name = os.path.join(forecast_path, f"bootstraps_labels_{station}.nc")
+                    labels = xr.DataArray(labels, coords=(*coords, ["obs"]), dims=["index", "ahead", "type"])
+                    labels.to_netcdf(file_name)
+
     def calculate_bootstrap_skill_scores(self):
 
         with TimeTracking(name="boot skill scores"):
@@ -110,6 +199,7 @@ class PostProcessing(RunEnvironment):
             skill_scores = statistics.SkillScores(None)
             score = {}
             for station in self.test_data.stations:
+                logging.info(station)
                 file_name = os.path.join(forecast_path, f"bootstraps_labels_{station}.nc")
                 labels = xr.open_dataarray(file_name)
                 shape = labels.shape
@@ -129,6 +219,62 @@ class PostProcessing(RunEnvironment):
                 score[station] = xr.DataArray(skill, dims=["boot_var", "ahead"])
             return score
 
+    def combined_boot_forecast_and_skill(self):
+        # forecast
+        with TimeTracking(name="boot predictions"):
+            bootstrap_path = self.data_store.get("bootstrap_path", "general")
+            forecast_path = self.data_store.get("forecast_path", "general")
+            window_lead_time = self.data_store.get("window_lead_time", "general")
+            bootstraps = BootStraps(self.test_data, bootstrap_path, 20)
+            skill_scores = statistics.SkillScores(None)
+            score = {}
+
+            for station in bootstraps.stations:
+                with TimeTracking(name=station):
+                    logging.info(station)
+                    # store also true labels for each station
+                    labels = bootstraps.get_labels(station)
+                    shape = labels.shape
+                    labels = labels.reshape((*shape, 1))
+                    coords = (range(labels.shape[0]), range(1, labels.shape[1] + 1))
+                    # file_name = os.path.join(forecast_path, f"bootstraps_labels_{station}.nc")
+                    labels = xr.DataArray(labels, coords=(*coords, ["obs"]), dims=["index", "ahead", "type"])
+                    # labels.to_netcdf(file_name)
+                    shape = labels.shape
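+                    # the original (non-bootstrapped) forecast is used as the reference for the skill scores below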
+                    orig = bootstraps.get_orig_prediction(forecast_path, f"forecasts_norm_{station}_test.nc").reshape(shape)
+                    coords = (range(shape[0]), range(1, shape[1] + 1), ["orig"])
+                    orig = xr.DataArray(orig, coords=coords, dims=["index", "ahead", "type"])
+                    skill = pd.DataFrame(columns=range(1, window_lead_time + 1))
+                    for var in bootstraps.variables:
+                        hist, label, station_bootstrap, length = bootstraps.get_generator_station_var_wise(station, var)
+
+                        # make bootstrap predictions
+                        bootstrap_predictions = self.model.predict_generator(generator=station_bootstrap(),
+                                                                             steps=length,
+                                                                             use_multiprocessing=True)
+                        if isinstance(bootstrap_predictions, list):
+                            bootstrap_predictions = bootstrap_predictions[-1]
+                        # get bootstrap prediction meta data
+                        bootstrap_meta = np.array(bootstraps.get_bootstrap_meta_station_var_wise(station, var))
+                        # save bootstrap predictions separately for each station and variable combination
+                        # store each variable - station - combination
+                        ind = np.all(bootstrap_meta == [var, station], axis=1)
+                        length = sum(ind)
+                        sel = bootstrap_predictions[ind].reshape((length, window_lead_time, 1))
+                        coords = (range(length), range(1, window_lead_time + 1))
+                        boot_data = xr.DataArray(sel, coords=(*coords, [var]), dims=["index", "ahead", "type"])
+                        # file_name = os.path.join(forecast_path, f"bootstraps_{var}_{station}.nc")
+                        # boot_data.to_netcdf(file_name)
+                        boot_data = boot_data.combine_first(labels).combine_first(orig)
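+                        # compute the skill score of the bootstrapped forecast against the original forecast for
+                        # each lead time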
+                        boot_scores = []
+                        for ahead in range(1, window_lead_time + 1):
+                            data = boot_data.sel(ahead=ahead)
+                            boot_scores.append(skill_scores.general_skill_score(data, forecast_name=var, reference_name="orig"))
+                        skill.loc[var] = np.array(boot_scores)
+                    score[station] = xr.DataArray(skill, dims=["boot_var", "ahead"])
+            return score
+
     def _load_model(self):
         try:
             model = self.data_store.get("best_model", "general")