From 59f09e79e82ff617326ada41b84c6949544c7085 Mon Sep 17 00:00:00 2001
From: leufen1 <l.leufen@fz-juelich.de>
Date: Thu, 29 Jun 2023 10:59:22 +0200
Subject: [PATCH] allow setting data path and file names from outside for era5 and ifs

---
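Note for reviewers (not part of the commit message): all new keyword arguments
default to None, so existing setups keep the previous behaviour of reading
"*.nc" files from the current directory. A minimal usage sketch for the era5
side, with an illustrative path and file pattern that are not taken from this
patch:

    from mlair.configuration.era5_settings import era5_settings

    # Without overrides the previous defaults are returned.
    data_path, file_names = era5_settings(sampling="hourly")
    assert (data_path, file_names) == (".", "*.nc")

    # Caller-supplied values take precedence over the defaults.
    data_path, file_names = era5_settings(
        sampling="hourly",
        era5_data_path="/path/to/local/era5",  # illustrative path
        era5_file_names="era5_*.nc",           # illustrative pattern
    )

ifs_settings accepts ifs_data_path and ifs_file_names in the same way, and
DataHandlerSingleStation forwards all four keyword arguments through
data_sources.download_data to load_era5 and load_ifs.
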
 mlair/configuration/era5_settings.py          |  6 +++---
 mlair/configuration/ifs_settings.py           |  6 +++---
 .../data_handler_single_station.py            | 19 ++++++++++++++++---
 mlair/helpers/data_sources/data_loader.py     | 10 +++++++---
 mlair/helpers/data_sources/era5.py            |  5 +++--
 mlair/helpers/data_sources/ifs.py             |  5 +++--
 mlair/run_modules/pre_processing.py           |  3 ++-
 7 files changed, 37 insertions(+), 17 deletions(-)

diff --git a/mlair/configuration/era5_settings.py b/mlair/configuration/era5_settings.py
index 9f44176b..7b09bd83 100644
--- a/mlair/configuration/era5_settings.py
+++ b/mlair/configuration/era5_settings.py
@@ -3,7 +3,7 @@
 from typing import Tuple
 
 
-def era5_settings(sampling="daily") -> Tuple[str, str]:
+def era5_settings(sampling="daily", era5_data_path=None, era5_file_names=None) -> Tuple[str, str]:
     """
     Check for sampling as only hourly resolution is supported by era5 and return path on HPC systems.
 
@@ -12,8 +12,8 @@ def era5_settings(sampling="daily") -> Tuple[str, str]:
     :return: HPC path
     """
     if sampling == "hourly":  # pragma: no branch
-        ERA5_DATA_PATH = "."
-        FILE_NAMES = "*.nc"
+        ERA5_DATA_PATH = era5_data_path or "."
+        FILE_NAMES = era5_file_names or "*.nc"
     else:
         raise NameError(f"Given sampling {sampling} is not supported, only hourly sampling can be used.")
     return ERA5_DATA_PATH, FILE_NAMES
diff --git a/mlair/configuration/ifs_settings.py b/mlair/configuration/ifs_settings.py
index f0e8ac49..dd40d72e 100644
--- a/mlair/configuration/ifs_settings.py
+++ b/mlair/configuration/ifs_settings.py
@@ -3,7 +3,7 @@
 from typing import Tuple
 
 
-def ifs_settings(sampling="daily") -> Tuple[str, str]:
+def ifs_settings(sampling="daily", ifs_data_path=None, ifs_file_names=None) -> Tuple[str, str]:
     """
     Check for sampling as only hourly resolution is supported by ifs and return path on HPC systems.
 
@@ -12,8 +12,8 @@ def ifs_settings(sampling="daily") -> Tuple[str, str]:
     :return: HPC path
     """
     if sampling == "hourly":  # pragma: no branch
-        IFS_DATA_PATH = "."
-        FILE_NAMES = "*.nc"
+        IFS_DATA_PATH = ifs_data_path or "."
+        FILE_NAMES = ifs_file_names or "*.nc"
     else:
         raise NameError(f"Given sampling {sampling} is not supported, only hourly sampling can be used.")
     return IFS_DATA_PATH, FILE_NAMES
diff --git a/mlair/data_handler/data_handler_single_station.py b/mlair/data_handler/data_handler_single_station.py
index e238deed..54093ab4 100644
--- a/mlair/data_handler/data_handler_single_station.py
+++ b/mlair/data_handler/data_handler_single_station.py
@@ -71,7 +71,8 @@ class DataHandlerSingleStation(AbstractDataHandler):
                  interpolation_method: Union[str, Tuple[str]] = DEFAULT_INTERPOLATION_METHOD,
                  overwrite_local_data: bool = False, transformation=None, store_data_locally: bool = True,
                  min_length: int = 0, start=None, end=None, variables=None, data_origin: Dict = None,
-                 lazy_preprocessing: bool = False, overwrite_lazy_data=False, **kwargs):
+                 lazy_preprocessing: bool = False, overwrite_lazy_data=False, era5_data_path=None, era5_file_names=None,
+                 ifs_data_path=None, ifs_file_names=None, **kwargs):
         super().__init__()
         self.station = helpers.to_list(station)
         self.path = self.setup_data_path(data_path, sampling)
@@ -115,6 +116,11 @@ class DataHandlerSingleStation(AbstractDataHandler):
         self.label = None
         self.observation = None
 
+        self._era5_data_path = era5_data_path
+        self._era5_file_names = era5_file_names
+        self._ifs_data_path = ifs_data_path
+        self._ifs_file_names = ifs_file_names
+
         # create samples
         self.setup_samples()
         self.clean_up()
@@ -343,7 +349,11 @@ class DataHandlerSingleStation(AbstractDataHandler):
             data, meta = data_sources.download_data(file_name, meta_file, station, statistics_per_var, sampling,
                                                     store_data_locally=store_data_locally, data_origin=data_origin,
                                                     time_dim=self.time_dim, target_dim=self.target_dim,
-                                                    iter_dim=self.iter_dim, window_dim=self.window_dim)
+                                                    iter_dim=self.iter_dim, window_dim=self.window_dim,
+                                                    era5_data_path=self._era5_data_path,
+                                                    era5_file_names=self._era5_file_names,
+                                                    ifs_data_path=self._ifs_data_path,
+                                                    ifs_file_names=self._ifs_file_names)
             logging.debug(f"{self.station[0]}: loaded new data")
         else:
             try:
@@ -358,7 +368,10 @@ class DataHandlerSingleStation(AbstractDataHandler):
                 data, meta = data_sources.download_data(file_name, meta_file, station, statistics_per_var, sampling,
                                                         store_data_locally=store_data_locally, data_origin=data_origin,
                                                         time_dim=self.time_dim, target_dim=self.target_dim,
-                                                        iter_dim=self.iter_dim)
+                                                        iter_dim=self.iter_dim, era5_data_path=self._era5_data_path,
+                                                        era5_file_names=self._era5_file_names,
+                                                        ifs_data_path=self._ifs_data_path,
+                                                        ifs_file_names=self._ifs_file_names)
                 logging.debug(f"{self.station[0]}: loading finished")
         # create slices and check for negative concentration.
         data = self._slice_prep(data, start=start, end=end)
diff --git a/mlair/helpers/data_sources/data_loader.py b/mlair/helpers/data_sources/data_loader.py
index a3b50746..fd0b5042 100644
--- a/mlair/helpers/data_sources/data_loader.py
+++ b/mlair/helpers/data_sources/data_loader.py
@@ -22,7 +22,9 @@ DEFAULT_WINDOW_DIM = "window"
 
 def download_data(file_name: str, meta_file: str, station, statistics_per_var, sampling,
                   store_data_locally=True, data_origin: Dict = None, time_dim=DEFAULT_TIME_DIM,
-                  target_dim=DEFAULT_TARGET_DIM, iter_dim=DEFAULT_ITER_DIM, window_dim=DEFAULT_WINDOW_DIM) -> [xr.DataArray, pd.DataFrame]:
+                  target_dim=DEFAULT_TARGET_DIM, iter_dim=DEFAULT_ITER_DIM, window_dim=DEFAULT_WINDOW_DIM,
+                  era5_data_path=None, era5_file_names=None, ifs_data_path=None, ifs_file_names=None) -> \
+        [xr.DataArray, pd.DataFrame]:
     """
     Download data from TOAR database using the JOIN interface or load local era5 data.
 
@@ -54,12 +56,14 @@ def download_data(file_name: str, meta_file: str, station, statistics_per_var, s
         # load era5 data
         df_era5_local, meta_era5_local = data_sources.era5.load_era5(
             station_name=station, stat_var=era5_local_stats, sampling=sampling, data_origin=era5_local_origin,
-            window_dim=window_dim, time_dim=time_dim, target_dim=target_dim)
+            window_dim=window_dim, time_dim=time_dim, target_dim=target_dim, era5_data_path=era5_data_path,
+            era5_file_names=era5_file_names)
     if ifs_local_origin is not None and len(ifs_local_stats) > 0:
         # load era5 data
         df_ifs_local, meta_ifs_local = data_sources.ifs.load_ifs(
             station_name=station, stat_var=ifs_local_stats, sampling=sampling, data_origin=ifs_local_origin,
-            lead_time_dim=window_dim, initial_time_dim=time_dim, target_dim=target_dim)
+            lead_time_dim=window_dim, initial_time_dim=time_dim, target_dim=target_dim, ifs_data_path=ifs_data_path,
+            ifs_file_names=ifs_file_names)
     if toar_origin is None or len(toar_stats) > 0:
         # load combined data from toar-data (v2 & v1)
         df_toar, meta_toar = data_sources.toar_data.download_toar(station=station, toar_stats=toar_stats,
diff --git a/mlair/helpers/data_sources/era5.py b/mlair/helpers/data_sources/era5.py
index 156df00d..14646231 100644
--- a/mlair/helpers/data_sources/era5.py
+++ b/mlair/helpers/data_sources/era5.py
@@ -17,13 +17,14 @@ from mlair.helpers.data_sources.data_loader import EmptyQueryResult
 from mlair.helpers.meteo import relative_humidity_from_dewpoint
 
 
-def load_era5(station_name, stat_var, sampling, data_origin, time_dim, window_dim, target_dim):
+def load_era5(station_name, stat_var, sampling, data_origin, time_dim, window_dim, target_dim, era5_data_path=None,
+              era5_file_names=None):
 
     # make sure station_name parameter is a list
     station_name = helpers.to_list(station_name)
 
     # get data path
-    data_path, file_names = era5_settings(sampling)
+    data_path, file_names = era5_settings(sampling, era5_data_path=era5_data_path, era5_file_names=era5_file_names)
 
     # correct stat_var values if data is not aggregated (hourly)
     if sampling == "hourly":
diff --git a/mlair/helpers/data_sources/ifs.py b/mlair/helpers/data_sources/ifs.py
index eba40dcc..ae16c33b 100644
--- a/mlair/helpers/data_sources/ifs.py
+++ b/mlair/helpers/data_sources/ifs.py
@@ -20,13 +20,14 @@ from mlair.helpers.data_sources.data_loader import EmptyQueryResult
 from mlair.helpers.meteo import relative_humidity_from_dewpoint
 
 
-def load_ifs(station_name, stat_var, sampling, data_origin, lead_time_dim, initial_time_dim, target_dim):
+def load_ifs(station_name, stat_var, sampling, data_origin, lead_time_dim, initial_time_dim, target_dim,
+             ifs_data_path=None, ifs_file_names=None):
 
     # make sure station_name parameter is a list
     station_name = helpers.to_list(station_name)
 
     # get data path
-    data_path, file_names = ifs_settings(sampling)
+    data_path, file_names = ifs_settings(sampling, ifs_data_path=ifs_data_path, ifs_file_names=ifs_file_names)
 
     # correct stat_var values if data is not aggregated (hourly)
     if sampling == "hourly":
diff --git a/mlair/run_modules/pre_processing.py b/mlair/run_modules/pre_processing.py
index d56d064a..dffabffe 100644
--- a/mlair/run_modules/pre_processing.py
+++ b/mlair/run_modules/pre_processing.py
@@ -461,7 +461,8 @@ class PreProcessing(RunEnvironment):
                            "neighbors", "plot_list", "plot_path", "regularizer", "restore_best_model_weights",
                            "snapshot_load_path", "snapshot_path", "stations", "tmp_path", "train_model",
                            "transformation", "use_multiprocessing", "cams_data_path", "cams_interp_method", 
-                           "do_bias_free_evaluation", "apriori_file", "model_path", "model_load_path"]
+                           "do_bias_free_evaluation", "apriori_file", "model_path", "model_load_path", "era5_data_path",
+                           "era5_file_names", "ifs_data_path", "ifs_file_names"]
         data_handler = self.data_store.get("data_handler")
         model_class = self.data_store.get("model_class")
         excluded_params = list(set(excluded_params + data_handler.store_attributes() + model_class.requirements()))
-- 
GitLab