From a8cc85f185cee28b5a27691a69faf9aa6d3ae545 Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Tue, 28 Jul 2020 14:27:40 +0200
Subject: [PATCH] MLAir is now independent of the window_lead_time parameter
 (it is extracted from the Y shape)

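PostProcessing no longer reads `window_lead_time` from the experiment
setup. Instead, the value is derived from the model's output shape
(data store entry `output_shape`, scope `model`) and unwrapped with the
new helper extract_value. The bootstrap shuffle drops its dask
dependency in favor of plain numpy, and leftovers that still relied on
the removed DataGenerator (check_valid_stations, the __main__ demo in
bootstraps.py) are deleted.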
---
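The core idea: window_lead_time equals the length of the lead-time axis
of the target data Y, so it no longer has to be carried around as a
separate experiment parameter. A minimal sketch of how it is recovered,
assuming targets shaped (samples, lead-time steps); apart from
extract_value, all names below are illustrative:

    import numpy as np

    Y = np.zeros((365, 4))      # hypothetical targets: 365 samples, 4 lead-time steps
    output_shape = Y.shape[1:]  # (4,) - the shape a data handler reports for Y

    def extract_value(encapsulated_value):
        """Unwrap nested single-element containers, e.g. [(4,)] -> 4."""
        if isinstance(encapsulated_value, str):
            return encapsulated_value
        try:
            return extract_value(encapsulated_value[0])
        except (TypeError, IndexError):
            return encapsulated_value

    window_lead_time = extract_value(output_shape)
    assert window_lead_time == 4

The dask-free bootstrap shuffle behaves as before, drawing with
replacement from the flattened data, roughly:

    import numpy as np

    data = np.arange(12, dtype=float).reshape(3, 4)
    shuffled = np.random.choice(data.reshape(-1), size=data.shape)  # with replacement
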
 src/data_handling/advanced_data_handling.py |  8 ++--
 src/data_handling/bootstraps.py             | 31 +--------------
 src/helpers/__init__.py                     |  2 +-
 src/helpers/helpers.py                      | 10 +++++
 src/plotting/postprocessing_plotting.py     |  5 +--
 src/run_modules/model_setup.py              |  2 -
 src/run_modules/post_processing.py          | 18 ++++-----
 src/run_modules/pre_processing.py           | 44 ---------------------
 8 files changed, 27 insertions(+), 93 deletions(-)

diff --git a/src/data_handling/advanced_data_handling.py b/src/data_handling/advanced_data_handling.py
index 63d26bd9..ffd6b235 100644
--- a/src/data_handling/advanced_data_handling.py
+++ b/src/data_handling/advanced_data_handling.py
@@ -13,7 +13,7 @@ import datetime as dt
 import shutil
 import inspect
 
-from typing import Union, List, Tuple
+from typing import Union, List, Tuple, Dict
 import logging
 from functools import reduce
 from src.data_handling.station_preparation import StationPrep
@@ -71,7 +71,7 @@ class AbstractDataPreparation:
 
     @classmethod
     def transformation(cls, *args, **kwargs):
-        raise NotImplementedError
+        return None
 
     def get_X(self, upsampling=False, as_numpy=False):
         raise NotImplementedError
@@ -82,8 +82,8 @@ class AbstractDataPreparation:
     def get_data(self, upsampling=False, as_numpy=False):
         return self.get_X(upsampling, as_numpy), self.get_Y(upsampling, as_numpy)
 
-    def get_coordinates(self):
-        return None, None
+    def get_coordinates(self) -> Union[None, Dict]:
+        return None
 
 
 class DefaultDataPreparation(AbstractDataPreparation):
diff --git a/src/data_handling/bootstraps.py b/src/data_handling/bootstraps.py
index 4aaa3cba..d026d551 100644
--- a/src/data_handling/bootstraps.py
+++ b/src/data_handling/bootstraps.py
@@ -17,7 +17,6 @@ import os
 from collections import Iterator, Iterable
 from itertools import chain
 
-import dask.array as da
 import numpy as np
 import xarray as xr
 
@@ -69,13 +68,12 @@ class BootstrapIterator(Iterator):
             return d.values
 
     @staticmethod
-    def shuffle(data: da.array) -> da.core.Array:
+    def shuffle(data: np.ndarray) -> np.ndarray:
         """
         Shuffle randomly from given data (draw elements with replacement).
 
         :param data: data to shuffle
-        :param chunks: chunk size for dask
-        :return: shuffled data as dask core array (not computed yet)
+        :return: shuffled data as numpy array
         """
         size = data.shape
         return np.random.choice(data.reshape(-1, ), size=size)
@@ -131,28 +129,3 @@ class BootStraps(Iterable):
         prediction = xr.open_dataarray(file).sel(type=prediction_name).squeeze()
         vals = np.tile(prediction.data, (self.number_of_bootstraps, 1))
         return vals[~np.isnan(vals).any(axis=1), :]
-
-
-
-
-if __name__ == "__main__":
-
-    from src.run_modules.experiment_setup import ExperimentSetup
-    from src.run_modules.run_environment import RunEnvironment
-    from src.run_modules.pre_processing import PreProcessing
-
-    formatter = '%(asctime)s - %(levelname)s: %(message)s  [%(filename)s:%(funcName)s:%(lineno)s]'
-    logging.basicConfig(format=formatter, level=logging.INFO)
-
-    with RunEnvironment() as run_env:
-        ExperimentSetup(stations=['DEBW107', 'DEBY081', 'DEBW013'],
-                        station_type='background', trainable=True, window_history_size=9)
-        PreProcessing()
-
-        data = run_env.data_store.get("generator", "general.test")
-        number_bootstraps = 10
-
-        boots = BootStraps(data, number_bootstraps)
-        for b in boots.boot_strap_generator():
-            a, c = b
-        logging.info(f"len is {len(boots.get_boot_strap_meta())}")
diff --git a/src/helpers/__init__.py b/src/helpers/__init__.py
index 546713b3..9e2f612c 100644
--- a/src/helpers/__init__.py
+++ b/src/helpers/__init__.py
@@ -3,4 +3,4 @@
 from .testing import PyTestRegex, PyTestAllEqual
 from .time_tracking import TimeTracking, TimeTrackingWrapper
 from .logger import Logger
-from .helpers import remove_items, float_round, dict_to_xarray, to_list
+from .helpers import remove_items, float_round, dict_to_xarray, to_list, extract_value
diff --git a/src/helpers/helpers.py b/src/helpers/helpers.py
index 968ee538..b12d9028 100644
--- a/src/helpers/helpers.py
+++ b/src/helpers/helpers.py
@@ -92,3 +92,13 @@ def remove_items(obj: Union[List, Dict], items: Any):
         return remove_from_dict(obj, items)
     else:
         raise TypeError(f"{inspect.stack()[0][3]} does not support type {type(obj)}.")
+
+
+def extract_value(encapsulated_value):
+    """Unwrap arbitrarily nested single-element containers and return the inner value."""
+    if isinstance(encapsulated_value, str):  # str is subscriptable and would recurse forever
+        return encapsulated_value
+    try:
+        return extract_value(encapsulated_value[0])
+    except (TypeError, IndexError):
+        return encapsulated_value
diff --git a/src/plotting/postprocessing_plotting.py b/src/plotting/postprocessing_plotting.py
index 7e282022..2fe71e23 100644
--- a/src/plotting/postprocessing_plotting.py
+++ b/src/plotting/postprocessing_plotting.py
@@ -19,7 +19,6 @@ import xarray as xr
 from matplotlib.backends.backend_pdf import PdfPages
 
 from src import helpers
-from src.data_handling import DataGenerator
 from src.data_handling.iterator import DataCollection
 from src.helpers import TimeTrackingWrapper
 
@@ -881,7 +880,7 @@ class PlotAvailability(AbstractPlotClass):
 
     """
 
-    def __init__(self, generators: Dict[str, DataGenerator], plot_folder: str = ".", sampling="daily",
+    def __init__(self, generators: Dict[str, DataCollection], plot_folder: str = ".", sampling="daily",
                  summary_name="data availability", time_dimension="datetime"):
         """Initialise."""
         # create standard Gantt plot for all stations (currently in single pdf file with single page)
@@ -927,7 +926,7 @@ class PlotAvailability(AbstractPlotClass):
                     plt_dict[str(station)].update({subset: t2})
         return plt_dict
 
-    def _summarise_data(self, generators: Dict[str, DataGenerator], summary_name: str):
+    def _summarise_data(self, generators: Dict[str, DataCollection], summary_name: str):
         plt_dict = {}
         for subset, data_collection in generators.items():
             all_data = None
diff --git a/src/run_modules/model_setup.py b/src/run_modules/model_setup.py
index 7de1c7b6..ea6199c9 100644
--- a/src/run_modules/model_setup.py
+++ b/src/run_modules/model_setup.py
@@ -31,8 +31,6 @@ class ModelSetup(RunEnvironment):
         * `trainable` [.]
         * `create_new_model` [.]
         * `generator` [train]
-        * `window_lead_time` [.]
-        * `window_history_size` [.]
         * `model_class` [.]
 
     Optional objects
diff --git a/src/run_modules/post_processing.py b/src/run_modules/post_processing.py
index f63e92ba..4b32ecb7 100644
--- a/src/run_modules/post_processing.py
+++ b/src/run_modules/post_processing.py
@@ -15,7 +15,7 @@ import xarray as xr
 
 from src.data_handling import BootStraps, KerasIterator
 from src.helpers.datastore import NameNotFoundInDataStore
-from src.helpers import TimeTracking, statistics
+from src.helpers import TimeTracking, statistics, extract_value
 from src.model_modules.linear_model import OrdinaryLeastSquaredModel
 from src.model_modules.model_class import AbstractModelClass
 from src.plotting.postprocessing_plotting import PlotMonthlySummary, PlotStationMap, PlotClimatologicalSkillScore, \
@@ -42,7 +42,7 @@ class PostProcessing(RunEnvironment):
         * `model_path` [.]
         * `target_var` [.]
         * `sampling` [.]
-        * `window_lead_time` [.]
+        * `output_shape` [model]
         * `evaluate_bootstraps` [postprocessing] and if enabled:
 
             * `create_new_bootstraps` [postprocessing]
@@ -74,6 +74,7 @@ class PostProcessing(RunEnvironment):
         self.plot_path: str = self.data_store.get("plot_path")
         self.target_var = self.data_store.get("target_var")
         self._sampling = self.data_store.get("sampling")
+        self.window_lead_time = extract_value(self.data_store.get("output_shape", "model"))
         self.skill_scores = None
         self.bootstrap_skill_scores = None
         self._run()
@@ -182,7 +183,6 @@ class PostProcessing(RunEnvironment):
             # extract all requirements from data store
             bootstrap_path = self.data_store.get("bootstrap_path")
             forecast_path = self.data_store.get("forecast_path")
-            window_lead_time = self.data_store.get("window_lead_time")
             number_of_bootstraps = self.data_store.get("number_of_bootstraps", "postprocessing")
             forecast_file = f"forecasts_norm_%s_test.nc"
             bootstraps = BootStraps(self.test_data[0], number_of_bootstraps).bootstraps()
@@ -203,14 +203,14 @@ class PostProcessing(RunEnvironment):
                 orig = xr.DataArray(orig, coords=coords, dims=["index", "ahead", "type"])
 
                 # calculate skill scores for each variable
-                skill = pd.DataFrame(columns=range(1, window_lead_time + 1))
+                skill = pd.DataFrame(columns=range(1, self.window_lead_time + 1))
                 for boot_set in bootstraps:
                     boot_var = f"{boot_set[0]}_{boot_set[1]}"
                     file_name = os.path.join(forecast_path, f"bootstraps_{station}_{boot_var}.nc")
                     boot_data = xr.open_dataarray(file_name)
                     boot_data = boot_data.combine_first(labels).combine_first(orig)
                     boot_scores = []
-                    for ahead in range(1, window_lead_time + 1):
+                    for ahead in range(1, self.window_lead_time + 1):
                         data = boot_data.sel(ahead=ahead)
                         boot_scores.append(
                             skill_scores.general_skill_score(data, forecast_name=boot_var, reference_name="orig"))
@@ -429,8 +429,7 @@ class PostProcessing(RunEnvironment):
         tmp_persi = data.copy()
         if not normalised:
             tmp_persi = statistics.apply_inverse_transformation(tmp_persi, mean, std, transformation_method)
-        window_lead_time = self.data_store.get("window_lead_time")
-        persistence_prediction.values = np.tile(tmp_persi, (window_lead_time, 1)).T
+        persistence_prediction.values = np.tile(tmp_persi, (self.window_lead_time, 1)).T
         return persistence_prediction
 
     def _create_nn_forecast(self, input_data: xr.DataArray, nn_prediction: xr.DataArray, mean: xr.DataArray,
@@ -547,7 +546,6 @@ class PostProcessing(RunEnvironment):
         :return: competitive and climatological skill scores
         """
         path = self.data_store.get("forecast_path")
-        window_lead_time = self.data_store.get("window_lead_time")
         skill_score_competitive = {}
         skill_score_climatological = {}
         for station in self.test_data:
@@ -555,7 +553,7 @@ class PostProcessing(RunEnvironment):
             data = xr.open_dataarray(file)
             skill_score = statistics.SkillScores(data)
             external_data = self._get_external_data(station)
-            skill_score_competitive[station] = skill_score.skill_scores(window_lead_time)
+            skill_score_competitive[station] = skill_score.skill_scores(self.window_lead_time)
             skill_score_climatological[station] = skill_score.climatological_skill_scores(external_data,
-                                                                                          window_lead_time)
+                                                                                          self.window_lead_time)
         return skill_score_competitive, skill_score_climatological
diff --git a/src/run_modules/pre_processing.py b/src/run_modules/pre_processing.py
index 1b7124d0..2e78887f 100644
--- a/src/run_modules/pre_processing.py
+++ b/src/run_modules/pre_processing.py
@@ -10,7 +10,6 @@ from typing import Tuple, Dict, List
 import numpy as np
 import pandas as pd
 
-from src.data_handling import DataGenerator
 from src.data_handling import DataCollection
 from src.helpers import TimeTracking
 from src.configuration import path_config
@@ -196,49 +195,6 @@ class PreProcessing(RunEnvironment):
         self.data_store.set("stations", valid_stations, scope=set_name)
         self.data_store.set("data_collection", collection, scope=set_name)
 
-    @staticmethod
-    def check_valid_stations(args: Dict, kwargs: Dict, all_stations: List[str], load_tmp=True, save_tmp=True,
-                             name=None):
-        """
-        Check if all given stations in `all_stations` are valid.
-
-        Valid means, that there is data available for the given time range (is included in `kwargs`). The shape and the
-        loading time are logged in debug mode.
-
-        :param args: Dictionary with required parameters for DataGenerator class (`data_path`, `network`, `stations`,
-            `variables`, `interpolate_dim`, `target_dim`, `target_var`).
-        :param kwargs: positional parameters for the DataGenerator class (e.g. `start`, `interpolate_method`,
-            `window_lead_time`).
-        :param all_stations: All stations to check.
-        :param name: name to display in the logging info message
-
-        :return: Corrected list containing only valid station IDs.
-        """
-        t_outer = TimeTracking()
-        t_inner = TimeTracking(start=False)
-        logging.info(f"check valid stations started{' (%s)' % name if name else ''}")
-        valid_stations = []
-
-        # all required arguments of the DataGenerator can be found in args, positional arguments in args and kwargs
-        data_gen = DataGenerator(**args, **kwargs)
-        for pos, station in enumerate(all_stations):
-            t_inner.run()
-            logging.info(f"check station {station} ({pos + 1} / {len(all_stations)})")
-            try:
-                data = data_gen.get_data_generator(key=station, load_local_tmp_storage=load_tmp,
-                                                   save_local_tmp_storage=save_tmp)
-                if data.history is None:
-                    raise AttributeError
-                valid_stations.append(station)
-                logging.debug(
-                    f'{station}: history_shape = {data.history.transpose("datetime", "window", "Stations", "variables").shape}')
-                logging.debug(f"{station}: loading time = {t_inner}")
-            except (AttributeError, EmptyQueryResult):
-                continue
-        logging.info(f"run for {t_outer} to check {len(all_stations)} station(s). Found {len(valid_stations)}/"
-                     f"{len(all_stations)} valid stations.")
-        return valid_stations
-
     def validate_station(self, data_preparation, set_stations, set_name=None, overwrite_local_data=False):
         """
         Check if all given stations in `all_stations` are valid.
-- 
GitLab