From 1ff35adf892d5b25d4d067b8f4f4e98fcfd01d54 Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Fri, 24 Jan 2020 11:53:38 +0100
Subject: [PATCH] minor change in data loading process: add an option to
 re-download data during preprocessing
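
With `overwrite_local_data=True`, `DataPrep.load_data()` now removes any
locally cached data and meta files for a station and downloads fresh copies
from JOIN; the default (`False`) keeps the previous behaviour of preferring
local files.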

---
 src/data_handling/data_preparation.py | 56 ++++++++++++++-------------
 src/run_modules/README.md             | 10 ++---
 src/run_modules/experiment_setup.py   |  3 +-
 src/run_modules/pre_processing.py     |  7 ++--
 4 files changed, 41 insertions(+), 35 deletions(-)
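
Usage note (placed after the diffstat, so it is ignored by `git am`): a
minimal sketch of how the new flag is meant to be switched on, assuming the
experiment is configured through `ExperimentSetup` as elsewhere in this
repository; the station list and any argument not named in this patch are
illustrative placeholders.

```python
# Illustrative sketch only: `ExperimentSetup` and the new keyword
# `overwrite_local_data` come from this patch; `stations` and its value
# are hypothetical placeholders for whatever the experiment normally uses.
from src.run_modules.experiment_setup import ExperimentSetup

# With overwrite_local_data=True, DataPrep.load_data() deletes any cached
# station files and re-downloads them from the JOIN interface during
# preprocessing; the default False keeps the local-data-first behaviour.
ExperimentSetup(stations=['DEBW107'], overwrite_local_data=True)
```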

diff --git a/src/data_handling/data_preparation.py b/src/data_handling/data_preparation.py
index db800f5e..d0d89438 100644
--- a/src/data_handling/data_preparation.py
+++ b/src/data_handling/data_preparation.py
@@ -1,7 +1,6 @@
 __author__ = 'Felix Kleinert, Lukas Leufen'
 __date__ = '2019-10-16'
 
-
 import xarray as xr
 import pandas as pd
 import numpy as np
@@ -12,7 +11,6 @@ from src import statistics
 from typing import Union, List, Iterable
 import datetime as dt
 
-
 # define a more general date type for type hinting
 date = Union[dt.date, dt.datetime]
 
@@ -72,35 +70,42 @@ class DataPrep(object):
         Load data and meta data either from local disk (preferred) or download new data from the TOAR database if no
         local data is available. In the latter case, the downloaded data is stored locally if desired (default: yes).
         """
-
         helpers.check_path_and_create(self.path)
         file_name = self._set_file_name()
         meta_file = self._set_meta_file_name()
-        try:
-
-            logging.debug(f"try to load local data from: {file_name}")
-            data = self._slice_prep(xr.open_dataarray(file_name))
-            self.data = self.check_for_negative_concentrations(data)
-            self.meta = pd.read_csv(meta_file, index_col=0)
-            if self.station_type is not None:
-                self.check_station_meta()
-            logging.debug("loading finished")
-        except FileNotFoundError as e:
-            logging.warning(e)
-            data, self.meta = self.download_data_from_join(file_name, meta_file)
-            data = self._slice_prep(data)
-            self.data = self.check_for_negative_concentrations(data)
+        if self.kwargs.get('overwrite_local_data', False):
+            logging.debug(f"overwrite_local_data is True, therefore reloading {file_name} from JOIN")
+            if os.path.exists(file_name):
+                os.remove(file_name)
+            if os.path.exists(meta_file):
+                os.remove(meta_file)
+            self.download_data(file_name, meta_file)
             logging.debug("loaded new data from JOIN")
+        else:
+            try:
+                logging.debug(f"try to load local data from: {file_name}")
+                data = self._slice_prep(xr.open_dataarray(file_name))
+                self.data = self.check_for_negative_concentrations(data)
+                self.meta = pd.read_csv(meta_file, index_col=0)
+                if self.station_type is not None:
+                    self.check_station_meta()
+                logging.debug("loading finished")
+            except FileNotFoundError as e:
+                logging.warning(e)
+                self.download_data(file_name, meta_file)
+                logging.debug("loaded new data from JOIN")
+
+    def download_data(self, file_name, meta_file):
+        data, self.meta = self.download_data_from_join(file_name, meta_file)
+        data = self._slice_prep(data)
+        self.data = self.check_for_negative_concentrations(data)
 
     def check_station_meta(self):
         """
         Search for the entries in meta data and compare the value with the requested values. Raise a FileNotFoundError
         if the values mismatch.
         """
-        check_dict = {
-            "station_type": self.station_type,
-            "network_name": self.network
-        }
+        check_dict = {"station_type": self.station_type, "network_name": self.network}
         for (k, v) in check_dict.items():
             if self.meta.at[k, self.station[0]] != v:
                 logging.debug(f"meta data does not agree which given request for {k}: {v} (requested) != "
@@ -138,8 +143,8 @@ class DataPrep(object):
         return f"Dataprep(path='{self.path}', network='{self.network}', station={self.station}, " \
                f"variables={self.variables}, station_type={self.station_type}, **{self.kwargs})"
 
-    def interpolate(self, dim: str, method: str = 'linear', limit: int = None,
-                    use_coordinate: Union[bool, str] = True, **kwargs):
+    def interpolate(self, dim: str, method: str = 'linear', limit: int = None, use_coordinate: Union[bool, str] = True,
+                    **kwargs):
         """
         (Copy paste from dataarray.interpolate_na)
         Interpolate values according to different methods.
@@ -193,6 +198,7 @@ class DataPrep(object):
         Perform inverse transformation
         :return:
         """
+
         def f_inverse(data, mean, std, method_inverse):
             if method_inverse == 'standardise':
                 return statistics.standardise_inverse(data, mean, std), None, None
@@ -319,8 +325,7 @@ class DataPrep(object):
         if (self.history is not None) and (self.label is not None):
             non_nan_history = self.history.dropna(dim=dim)
             non_nan_label = self.label.dropna(dim=dim)
-            intersect = np.intersect1d(non_nan_history.coords[dim].values,
-                                       non_nan_label.coords[dim].values)
+            intersect = np.intersect1d(non_nan_history.coords[dim].values, non_nan_label.coords[dim].values)
 
         if len(intersect) == 0:
             self.history = None
@@ -382,6 +387,5 @@ class DataPrep(object):
 
 
 if __name__ == "__main__":
-
     dp = DataPrep('data/', 'dummy', 'DEBW107', ['o3', 'temp'], statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'})
     print(dp)
diff --git a/src/run_modules/README.md b/src/run_modules/README.md
index eab10c72..33149220 100644
--- a/src/run_modules/README.md
+++ b/src/run_modules/README.md
@@ -4,7 +4,7 @@ This readme declares which function loads which data and where it is stored.
 
 ## experiment setup
 
-*Data_path* is the destination where all downloaded data is locally stored. Data is downloaded from TOARDB either using 
+**data_path** is the destination where all downloaded data is stored locally. Data is downloaded from TOARDB either using 
 the JOIN interface or a direct connection to the underlying PostgreSQL DB. If data was already downloaded, no new 
 download will be started. Missing data will be downloaded on the fly and saved in data_path. 
 
@@ -21,7 +21,7 @@ download will be started. Missing data will be downloaded on the fly and saved i
  | juwels | `/p/home/jusers/{user}/juwels/intelliaq/DATA/toar_daily/` | JUWELS |
  | runner-6HmDp9Qd-project-2411-concurrent | `/home/{user}/machinelearningtools/data/toar_daily/` | gitlab-runner |
 
-*experiment_path* is the root folder in that all results from the experiment are saved. For each experiment there should
+**experiment_path** is the root folder in which all results of the experiment are saved. For each experiment there should
 be a distinct folder. The experiment path can be set in ExperimentSetup. `experiment_date` and `experiment_path` (this 
 argument is not the same as the internally stored experiment_path!) can be passed as parser args. The *experiment_path*
 is the combination of both given arguments: `os.path.join(experiment_path, f"{experiment_date}_network")`. Inside this
@@ -57,10 +57,10 @@ experiment_path
 
 ```
 
-*plot_path* includes all created plots. If not given, this is create into the experiment_path by default (as shown in 
+**plot_path** includes all created plots. If not given, this folder is created inside the *experiment_path* by default (as shown in 
 the folder structure above). Can be customised by `ExperimentSetup(plot_path=<path>)`.
 
-*forecast_path* is the place, where all forecasts are stored as netcdf file. Each file consists exactly one single
+**forecast_path** is the place where all forecasts are stored as netCDF files. Each file contains exactly one single
 station. If not given, this folder is created inside the *experiment_path* by default (as shown in the folder structure above). Can 
 be customised by `ExperimentSetup(forecast_path=<path>)`.
 
@@ -77,7 +77,7 @@ in `experiment_setup.py` to overwrite local data by downloading new data.
 
 ## model setup
 
-*checkpoint* is created inside *experiment_path* as `<experiment_name>_model-best.h5`.
+**checkpoint** is created inside *experiment_path* as `<experiment_name>_model-best.h5`.
 
 The architecture of the model is plotted into *experiment_path* as `<experiment_name>_model.pdf` 
 
diff --git a/src/run_modules/experiment_setup.py b/src/run_modules/experiment_setup.py
index a46e2b17..cc2c71f9 100644
--- a/src/run_modules/experiment_setup.py
+++ b/src/run_modules/experiment_setup.py
@@ -33,7 +33,7 @@ class ExperimentSetup(RunEnvironment):
                  window_lead_time=None, dimensions=None, interpolate_dim=None, interpolate_method=None,
                  limit_nan_fill=None, train_start=None, train_end=None, val_start=None, val_end=None, test_start=None,
                  test_end=None, use_all_stations_on_all_data_sets=True, trainable=False, fraction_of_train=None,
-                 experiment_path=None, plot_path=None, forecast_path=None):
+                 experiment_path=None, plot_path=None, forecast_path=None, overwrite_local_data=None):
 
         # create run framework
         super().__init__()
@@ -70,6 +70,7 @@ class ExperimentSetup(RunEnvironment):
         self._set_param("start", start, default="1997-01-01", scope="general")
         self._set_param("end", end, default="2017-12-31", scope="general")
         self._set_param("window_history_size", window_history_size, default=13)
+        self._set_param("overwrite_local_data", overwrite_local_data, default=False, scope="general.preprocessing")
 
         # target
         self._set_param("target_var", target_var, default="o3")
diff --git a/src/run_modules/pre_processing.py b/src/run_modules/pre_processing.py
index 8e11877f..6ab1f0dd 100644
--- a/src/run_modules/pre_processing.py
+++ b/src/run_modules/pre_processing.py
@@ -12,7 +12,8 @@ from src.join import EmptyQueryResult
 
 
 DEFAULT_ARGS_LIST = ["data_path", "network", "stations", "variables", "interpolate_dim", "target_dim", "target_var"]
-DEFAULT_KWARGS_LIST = ["limit_nan_fill", "window_history_size", "window_lead_time", "statistics_per_var", "station_type"]
+DEFAULT_KWARGS_LIST = ["limit_nan_fill", "window_history_size", "window_lead_time", "statistics_per_var",
+                       "station_type", "overwrite_local_data"]
 
 
 class PreProcessing(RunEnvironment):
@@ -33,8 +34,8 @@ class PreProcessing(RunEnvironment):
         self._run()
 
     def _run(self):
-        args = self.data_store.create_args_dict(DEFAULT_ARGS_LIST)
-        kwargs = self.data_store.create_args_dict(DEFAULT_KWARGS_LIST)
+        args = self.data_store.create_args_dict(DEFAULT_ARGS_LIST, scope="general.preprocessing")
+        kwargs = self.data_store.create_args_dict(DEFAULT_KWARGS_LIST, scope="general.preprocessing")
         valid_stations = self.check_valid_stations(args, kwargs, self.data_store.get("stations", "general"))
         self.data_store.put("stations", valid_stations, "general")
         self.split_train_val_test()
-- 
GitLab