From 4efbe26e09c63bca2fc30d02b1accab814a72240 Mon Sep 17 00:00:00 2001
From: Felix Kleinert <f.kleinert@fz-juelich.de>
Date: Fri, 26 Mar 2021 13:47:05 +0100
Subject: [PATCH] include joint extractor for vertical layers

---
 mlair/data_handler/data_handler_wrf_chem.py | 19 +++++++++++++++++--
 1 file changed, 17 insertions(+), 2 deletions(-)

diff --git a/mlair/data_handler/data_handler_wrf_chem.py b/mlair/data_handler/data_handler_wrf_chem.py
index 81a1458a..227a609b 100644
--- a/mlair/data_handler/data_handler_wrf_chem.py
+++ b/mlair/data_handler/data_handler_wrf_chem.py
@@ -45,8 +45,10 @@ class BaseWrfChemDataLoader:
     DEFAULT_STAGED_ROTATION_opts = dict(cen_lon=12., cen_lat=52.5,
                                         truelat1=30., truelat2=60.,
                                         stand_lon=12.)
-    DEFAULT_VARS_TO_ROTATE = ((('U', 'V'), ('Ull', 'Vll')), (('U10', 'V10'), ('U10ll', 'V10ll')),
-                              (('U', 'V'), ('wspdll', 'wdirll')), (('U10', 'V10'), ('wspd10ll', 'wdir10ll'))
+    DEFAULT_VARS_TO_ROTATE = ((('U', 'V'), ('Ull', 'Vll')),
+                              (('U10', 'V10'), ('U10ll', 'V10ll')),
+                              (('U', 'V'), ('wspdll', 'wdirll')),
+                              (('U10', 'V10'), ('wspd10ll', 'wdir10ll'))
                               )
 
     def __init__(self, data_path: str, common_file_starter: str = DEFAULT_FILE_STARTER,
@@ -384,6 +386,8 @@ class DataHandlerSingleGridColumn(DataHandlerSingleStation):
         self.targetvar_logical_z_coord_selector = self._ret_z_coord_select_if_valid(targetvar_logical_z_coord_selector,
                                                                                     as_input=False)
         self._logical_z_coord_name = None
+        self._joint_z_coord_selector = self._extract_largest_coord_extractor(self.var_logical_z_coord_selector,
+                                                                             self.targetvar_logical_z_coord_selector)
         super().__init__(*args, **kwargs)
 
     @staticmethod
@@ -431,10 +435,13 @@ class DataHandlerSingleGridColumn(DataHandlerSingleStation):
             self._logical_z_coord_name = sgc_loader.logical_z_coord_name
         # # select defined variables at grid box or grid coloumn based on nearest icoords
         data = self.extract_data_from_loader(sgc_loader, station)
+        if self._joint_z_coord_selector is not None:
+            data = data.sel({self._logical_z_coord_name: self._joint_z_coord_selector})
         # expand dimesion for iterdim
         data = data.expand_dims({self.iter_dim: station}).to_array(self.target_dim)
         # transpose dataarray: set first three fixed and keep remaining as is
         data = data.transpose(self.iter_dim, self.time_dim, self.target_dim, ...)
+
         data = dask.compute(self._slice_prep(data, start=start, end=end))[0]
         sgc_loader.data.close()
         # data = self.check_for_negative_concentrations(data)
@@ -445,6 +452,14 @@ class DataHandlerSingleGridColumn(DataHandlerSingleStation):
         meta = None
         return data, meta
 
+    @staticmethod
+    def _extract_largest_coord_extractor(var_extarctor, target_extractor) -> Union[List, None]:
+        if var_extarctor is not None and target_extractor is not None:
+            res = list(set(to_list(var_extarctor) + to_list(target_extractor)))
+        else:
+            res = None
+        return res
+
     def get_X(self, upsampling=False, as_numpy=False):
         if as_numpy is True:
             # return None
-- 
GitLab