From f8ed17e51561c46a29f969c2fbba2a33ecb8a579 Mon Sep 17 00:00:00 2001
From: leufen1 <l.leufen@fz-juelich.de>
Date: Mon, 13 Feb 2023 15:24:00 +0100
Subject: [PATCH] Allow different interpolation methods for the CAMS competitor
 (several at once, each as a separate reference)

---
 mlair/reference_models/reference_model_cams.py | 18 ++++++++++++++++--
 mlair/run_modules/pre_processing.py            | 12 +++++++++++-
 2 files changed, 27 insertions(+), 3 deletions(-)

diff --git a/mlair/reference_models/reference_model_cams.py b/mlair/reference_models/reference_model_cams.py
index 1db19c05..4c920cff 100644
--- a/mlair/reference_models/reference_model_cams.py
+++ b/mlair/reference_models/reference_model_cams.py
@@ -11,7 +11,16 @@ import pandas as pd
 
 class CAMSforecast(AbstractReferenceModel):
 
-    def __init__(self, ref_name: str, ref_store_path: str = None, data_path: str = None):
+    def __init__(self, ref_name: str, ref_store_path: str = None, data_path: str = None, interp_method: str = None):
+        """
+        In the MLAir run script, use the parameter `cams_data_path` to set `data_path` and the parameter
+        `cams_interp_method` to set `interp_method`.
+
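+        :param ref_name: name of this reference, also used as its label along the type dimension
+        :param ref_store_path: location used to store the prepared reference (presumably a directory; TODO confirm)
+        :param data_path: path to the CAMS forecast files, kept as absolute path (current directory as fallback)
+        :param interp_method: method passed to xarray's `interp`; if None, the nearest grid cell is selected instead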
+        """
 
         super().__init__()
         self.ref_name = ref_name
@@ -22,6 +31,7 @@ class CAMSforecast(AbstractReferenceModel):
             self.data_path = os.path.abspath(".")
         else:
             self.data_path = os.path.abspath(data_path)
+        self.interp_method = interp_method
         self.file_pattern = "forecasts_%s_test.nc"
         self.time_dim = "index"
         self.ahead_dim = "ahead"
@@ -36,7 +46,11 @@ class CAMSforecast(AbstractReferenceModel):
             darray = dataset.to_array().sortby(["longitude", "latitude"])
             for station, coords in missing_stations.items():
                 lon, lat = coords["lon"], coords["lat"]
-                station_data = darray.sel(longitude=lon, latitude=lat, method="nearest", drop=True).squeeze(drop=True)
+                if self.interp_method is None:
+                    station_data = darray.sel(longitude=lon, latitude=lat, method="nearest", drop=True).squeeze(drop=True)
+                else:
+                    station_data = darray.interp(**{"longitude": lon, "latitude": lat}, method=self.interp_method)
+                    station_data = station_data.drop_vars(["longitude", "latitude"]).squeeze(drop=True)
                 station_data = station_data.expand_dims(dim={self.type_dim: [self.ref_name]}).compute()
                 station_data.coords[self.time_dim] = station_data.coords[self.time_dim] - pd.Timedelta(days=1)
                 station_data.coords[self.ahead_dim] = station_data.coords[self.ahead_dim] + 1
diff --git a/mlair/run_modules/pre_processing.py b/mlair/run_modules/pre_processing.py
index fc1ae4b7..9501d36f 100644
--- a/mlair/run_modules/pre_processing.py
+++ b/mlair/run_modules/pre_processing.py
@@ -378,13 +378,23 @@ class PreProcessing(RunEnvironment):
                 elif competitor_name.lower() == "CAMS".lower():
                     logging.info("Prepare CAMS forecasts")
                     from mlair.reference_models.reference_model_cams import CAMSforecast
+                    interp_method = self.data_store.get_default("cams_interp_method", default=None)
                     data_path = self.data_store.get_default("cams_data_path", default=None)
                     path = os.path.join(self.data_store.get("competitor_path"), competitor_name)
                     stations = {}
                     for subset in ["train", "val", "test"]:
                         data_collection = self.data_store.get("data_collection", subset)
                         stations.update({str(s): s.get_coordinates() for s in data_collection if s not in stations})
-                    CAMSforecast("CAMS", ref_store_path=path, data_path=data_path).make_reference_available_locally(stations)
+                    if interp_method is None:
+                        CAMSforecast("CAMS", ref_store_path=path, data_path=data_path, interp_method=None
+                                     ).make_reference_available_locally(stations)
+                    else:
+                        competitors = remove_items(competitors, "CAMS")
+                        for method in to_list(interp_method):
+                            CAMSforecast(f"CAMS{method}", ref_store_path=path + method, data_path=data_path,
+                                         interp_method=method).make_reference_available_locally(stations)
+                            competitors.append(f"CAMS{method}")
+                        self.data_store.set("competitors", competitors)
                 else:
                     logging.info(f"No preparation required for competitor {competitor_name} as no specific instruction "
                                  f"is provided.")
-- 
GitLab