From 24df77d8f4c2bc2cb618ed4ae5be49f481f0e48c Mon Sep 17 00:00:00 2001
From: Felix Kleinert <f.kleinert@fz-juelich.de>
Date: Thu, 20 May 2021 13:26:00 +0200
Subject: [PATCH] add some documentation

---
 mlair/data_handler/data_handler_wrf_chem.py | 225 ++++++++++++++++----
 1 file changed, 187 insertions(+), 38 deletions(-)

diff --git a/mlair/data_handler/data_handler_wrf_chem.py b/mlair/data_handler/data_handler_wrf_chem.py
index 763a264e..9688a413 100644
--- a/mlair/data_handler/data_handler_wrf_chem.py
+++ b/mlair/data_handler/data_handler_wrf_chem.py
@@ -29,6 +29,9 @@ int_or_list_of_int = Union[int, List[int]]
 
 
 class BaseWrfChemDataLoader:
+    """
+    Base class to load WRF-Chem data.
+    """
     DEFAULT_LOGICAL_TIME_COORD_NAME = 'Time'
     DEFAULT_LOGICAL_X_COORD_NAME = 'west_east'
     DEFAULT_LOGICAL_Y_COORD_NAME = 'south_north'
@@ -60,7 +63,6 @@ class BaseWrfChemDataLoader:
                  physical_x_coord_name: str = DEFAULT_PHYSICAL_X_COORD_NAME,
                  physical_y_coord_name: str = DEFAULT_PHYSICAL_Y_COORD_NAME,
                  physical_t_coord_name: str = DEFAULT_PHYSICAL_TIME_COORD_NAME,
-                 staged_vars: List[str] = DEFAULT_STAGED_VARS,
                  variables=None, z_coord_selector=None,
                  start_time=None, end_time=None,
                  staged_rotation_opts: Dict = DEFAULT_STAGED_ROTATION_opts,
@@ -69,6 +71,33 @@ class BaseWrfChemDataLoader:
                  date_format_of_nc_file=None,
 
                  ):
+        """
+        Initialise data loader
+
+        :param data_path: path to WRF-Chem data
+        :param common_file_starter: beginning of each individual file
+        :param time_dim_name: Name of time dimension in data file
+        :param rechunk_values: Mapping for chunk sizes
+        :param logical_x_coord_name: Logical x coord name (e.g. west-east direction)
+        :param logical_y_coord_name: Logical y coord name (e.g. south-north direction)
+        :param logical_z_coord_name: Logical z coord name (e.g. bottom-top of atmosphere direction)
+        :param physical_x_coord_name: Physical x coord name (e.g. XLONG, might be 2D)
+        :param physical_y_coord_name: Physical y coord name (e.g. XLAT, might be 2D)
+        :param physical_t_coord_name: Physical time coord name
+        :param variables: variables to read from files
+        :param z_coord_selector: physical z coords to read from file
+        :param start_time: first physical time coord to read from file
+        :param end_time: end physical time coord to read from file
+        :param staged_rotation_opts: mapping of options needed for vector transformation
+        :param vars_to_rotate: name of variables which should be rotated from one coord system to another
+                (e.g. ((('U', 'V'), ('Ull', 'Vll')), (('U10', 'V10'), ('U10ll', 'V10ll'))) to rotate U and V
+                in coord sys A to Ull and Vll in coord sys B (e.g. geographic)).
+        :param staged_dimension_mapping: mapping of staged dimensions and names
+        :type staged_dimension_mapping:
+        :param stag_ending: ending of staged variables identifier (most likely provided in the metadata of the WRF-Chem
+        file)
+        :param date_format_of_nc_file: date format of input file names (e.g. %Y-%m-%d)
+        """
         # super().__init__()
         self.data_path = data_path
         self.common_file_starter = common_file_starter
@@ -94,10 +123,6 @@ class BaseWrfChemDataLoader:
             date_format_of_nc_file = "%Y-%m-%d"
         self.date_format_of_nc_file = date_format_of_nc_file
 
-        # if rechunk_values is None:
-        #     self.rechunk_values = {self.time_dim_name: 1}
-        # else:
-        #     self.rechunk_values = rechunk_values
         self.rechunk_values = rechunk_values
 
         self._stag_ending = stag_ending
@@ -125,10 +150,20 @@ class BaseWrfChemDataLoader:
 
     @property
     def data(self) -> xr.Dataset:
+        """
+        Returns the actual data of the WRF-Chem files
+        :return:
+        :rtype:
+        """
         return self._data
 
     @property
     def dataset_search_str(self):
+        """
+        Returns a string to search for existing WRF-Chem data
+        :return:
+        :rtype:
+        """
         if (self.start_time is None) and (self.end_time is None):
             path_list = os.path.join(self.data_path, self.common_file_starter + '*')
             logging.info(f"Reading file(s): {path_list}")
@@ -148,9 +183,11 @@ class BaseWrfChemDataLoader:
 
     @TimeTrackingWrapper
     def open_data(self):
-        # see also https://github.com/pydata/xarray/issues/1385#issuecomment-438870575
-        # data = xr.open_mfdataset(paths=self.dataset_search_str, combine='nested', concat_dim=self.time_dim_name,
-        #                          parallel=True, decode_cf=False)
+        """
+        Opens WRF-Chem data as dataset
+        :return:
+        :rtype:
+        """
         if self.variables is None:
             # see also https://github.com/pydata/xarray/issues/1385#issuecomment-438870575
             data = xr.open_mfdataset(paths=self.dataset_search_str, combine='nested', concat_dim=self.time_dim_name,
@@ -159,27 +196,13 @@ class BaseWrfChemDataLoader:
             data = xr.open_mfdataset(paths=self.dataset_search_str, combine='nested', concat_dim=self.time_dim_name,
                                      parallel=True, decode_cf=False, preprocess=self.preprocess_fkt_for_loader, )
         data = xr.decode_cf(data)
-        # if self.variables is not None:
-        #     data = self.preprocess_fkt_for_loader(data)
-
-        # if self.rechunk_values is None:
-        #     chunks = {k: 'auto' for k in data.chunks.keys() }
-        #     chunks[self.time_dim_name] = -1
-        #     data = data.chunk(chunks)
-        #     # data = data.chunk("auto")
-        # else:
-        #     data = data.chunk(self.rechunk_values)
         self._data = data
 
     def preprocess_fkt_for_loader(self, ds):
         # ToDo make genreal: Currently it's in fact hardcoded!
-        wind_var_mapping = {'Ull': ['U'], 'Vll': ['V'], 'U10ll': ['U10'], 'V10ll': ['V10'],
-                            'wspdll': ['U', 'V'], 'wdirll': ['U', 'V'],
-                            'wspd10ll': ['U10', 'V10'], 'wdir10ll': ['U10', 'V10'],
-                            }
         potential_wind_vars_list = list(
             set(itertools.chain(
-                *itertools.chain(*SingleGridColumnWrfChemDataLoader.DEFAULT_VARS_TO_ROTATE))))
+                *itertools.chain(*self.vars_to_rotate))))
         none_wind_vars_to_keep = [x for x in self.variables if x not in potential_wind_vars_list]
         wind_vars_to_keep = ['U', 'V', 'U10', 'V10']
         combined_vars_to_keep = none_wind_vars_to_keep + wind_vars_to_keep
@@ -216,10 +239,25 @@ class BaseWrfChemDataLoader:
             print(f'{i}: {filenamestr} {status}')
 
     def get_distances(self, lat, lon):
+        """
+        Calculates distances from given lat, lon to all grid boxes of model domain.
+        :param lat: latitude
+        :param lon: longitude
+        :return:
+        :rtype:
+        """
         dist = haversine_dist(lat1=self._data.XLAT, lon1=self._data.XLONG, lat2=lat, lon2=lon)
         return dist
 
     def get_bearing(self, lat, lon, points_last=True):
+        """
+        Calculates bearings from a given lat, lon to all grid boxes of model domain.
+        :param lat: latitude
+        :param lon:  longitude
+        :param points_last: return transpose with points as last dimensions (if True, default)
+        :return:
+        :rtype:
+        """
         bearing = bearing_angle(lat1=lat, lon1=lon, lat2=self._data.XLAT, lon2=self._data.XLONG)
         if points_last:
             return bearing.transpose(..., 'points')
@@ -227,6 +265,15 @@ class BaseWrfChemDataLoader:
             return bearing
 
     def compute_nearest_icoordinates(self, lat, lon, dim=None):
+        """
+        Calculate nearest logical coordinate (grid box) for a given location (lat, lon)
+
+        :param lat: latitude
+        :param lon: longitude
+        :param dim: dimension to apply argmin function
+        :return:
+        :rtype:
+        """
         dist = self.get_distances(lat=lat, lon=lon)
 
         if dim is None:
@@ -235,6 +282,11 @@ class BaseWrfChemDataLoader:
             return dist.argmin(dim)
 
     def _set_dims_as_coords(self):
+        """
+        Set dimensions as coordinates
+        :return:
+        :rtype:
+        """
         if self._data is None:
             raise IOError(f'{self.__class__.__name__} can not set dims as coords. Use must use `open_data()` before.')
         data = self._data
@@ -244,6 +296,15 @@ class BaseWrfChemDataLoader:
         # logging.info('set dimensions as coordinates')
 
     def apply_staged_transormation(self, mapping_of_stag2unstag=None):
+        """
+        Apply vector transformation on staged variables
+
+        :param mapping_of_stag2unstag: mapping from staged field as key to unstaged field as value. Note that U10 and V10
+        in WRF-Chem are located at grid box centers while U and V are located on edges.
+        :type mapping_of_stag2unstag:
+        :return:
+        :rtype:
+        """
         if mapping_of_stag2unstag is None:
             mapping_of_stag2unstag = {'U': 'U10', 'V': 'V10', 'U10': 'U10', 'V10': 'V10'}
         vectorrot = VectorRotateLambertConformal2latlon(xlat=self._data[self.physical_y_coord_name].data,
@@ -291,17 +352,30 @@ class BaseWrfChemDataLoader:
 
     def set_interpolated_field(self, staged_field: xr.DataArray, target_field: xr.DataArray,
                                dropped_staged_attrs: List[str] =None, **kwargs):
+        """
+        Interpolates a staged field (variables on edges of grid box) to a given target field (center of grid box)
+
+        :param staged_field: field of staged variables
+        :param target_field: target field to interpolate on
+        :param dropped_staged_attrs: attribute which indicates that field is staged and therefore should be removed
+        as attribute for the target field
+        :param kwargs:
+        :return: Interpolated field
+        :rtype:
+        """
         stagger_attr_name = kwargs.pop('stagger_attr_name', 'stagger')
         stagger = kwargs.pop('stagger', staged_field.attrs[stagger_attr_name])
         if dropped_staged_attrs is None:
             dropped_staged_attrs = [stagger_attr_name]
 
+        # prepare new (interpolated) field by setting metainfo etc
         excluded_dims = [self.staged_dimension_mapping[stagger]]
         target_dim_order = [dim.replace('_stag', '') for dim in staged_field.dims]
         new_field = xr.broadcast(target_field, staged_field, exclude=excluded_dims)[0]
         new_field = new_field.transpose(*target_dim_order)
         new_field.attrs.update(remove_items(staged_field.attrs, dropped_staged_attrs))
 
+        # interpolate from staged to unstaged grid
         new_field.data = VectorRotateLambertConformal2latlon.interpolate_to_grid_center(staged_field,
                                                                                         self.staged_dimension_mapping[stagger])
         return new_field
@@ -326,6 +400,26 @@ class SingleGridColumnWrfChemDataLoader(BaseWrfChemDataLoader):
                  window_lead_time=DEFAULT_WINDOW_LEAD_TIME,
                  external_coords_file: str = None,
                  wind_sectors=None, **kwargs):
+        """
+
+
+        :param coords:
+        :type coords:
+        :param target_dim:
+        :type target_dim:
+        :param target_var:
+        :type target_var:
+        :param window_history_size:
+        :type window_history_size:
+        :param window_lead_time:
+        :type window_lead_time:
+        :param external_coords_file:
+        :type external_coords_file:
+        :param wind_sectors:
+        :type wind_sectors:
+        :param kwargs:
+        :type kwargs:
+        """
         super().__init__(**kwargs)
         self._set_coords(coords)
         self.target_dim = target_dim
@@ -340,6 +434,12 @@ class SingleGridColumnWrfChemDataLoader(BaseWrfChemDataLoader):
         logging.debug("SingleGridColumnWrfChemDataLoader Initialised")
 
     def __enter__(self):
+        """
+        Enter method which opens WRF-Chem data, sets coordinates (potentially from external coord file)
+
+        :return:
+        :rtype:
+        """
         self.open_data()
 
         if self.physical_t_coord_name != self.time_dim_name:
@@ -359,6 +459,12 @@ class SingleGridColumnWrfChemDataLoader(BaseWrfChemDataLoader):
         gc.collect()
 
     def _set_geoinfos(self):
+        """
+        Sets geoinfos like nearest logical and physical coordinates, distances and bearings to/from a point of interest.
+
+        :return:
+        :rtype:
+        """
         # identify nearest coords
         self._set_nearest_icoords(dim=[self.logical_x_coord_name, self.logical_y_coord_name])
         self._set_nearest_coords()
@@ -385,21 +491,26 @@ class SingleGridColumnWrfChemDataLoader(BaseWrfChemDataLoader):
         # expand dataset by new dataarray
         self._geo_infos['wind_sectors'] = xr.full_like(bearing, fill_value=np.nan) # xr.DataArray(coords=bearing.coords, dims=bearing.dims)
         for i, (k, v) in enumerate(ws.wind_sectore_edges.items()):
-            self._geo_infos['wind_sectors'] = xr.where(ws.is_in_sector(k, self.geo_infos['bearing']), i, self._geo_infos['wind_sectors'])
-        # self._geo_infos['wind_sectors'].attrs.update(dict(units='Wind Sector'))
+            self._geo_infos['wind_sectors'] = xr.where(
+                ws.is_in_sector(k, self.geo_infos['bearing']), i, self._geo_infos['wind_sectors'])
 
     @property
-    def geo_infos(self):
+    def geo_infos(self) -> xr.Dataset:
         return self._geo_infos
 
     def _apply_external_coordinates(self):
+        """
+        Apply external coordinates to data
+        :return:
+        :rtype:
+        """
         ds_coords = xr.open_dataset(self.external_coords_file, chunks={'south_north': 36, 'west_east': 40})
         data = self._data
         for k, v in ds_coords.coords.variables.items():
             data = data.assign_coords({k: (remove_items(list(v.dims), 'Time'), v.values.squeeze())})
         self._data = data
         ds_coords.close()
-        # logging.debug('setup external coords')
+        logging.debug('setup external coords')
 
     def _set_coords(self, coords):
         __set_coords = dict(lat=None, lon=None)
@@ -409,7 +520,7 @@ class SingleGridColumnWrfChemDataLoader(BaseWrfChemDataLoader):
             self.__coords = dict(lat=coords[0], lon=coords[1])
             logging.debug(f"{self.__class__.__name__}.__coords is set to {self.__coords}")
         elif isinstance(coords, dict):
-            if (coords.keys() == __set_coords.keys()):
+            if coords.keys() == __set_coords.keys():
                 self.__coords = coords
             else:
                 raise KeyError(f"dict `coords' must have the keys `lat', and `lon' but has: {coords.keys()}")
@@ -417,18 +528,38 @@ class SingleGridColumnWrfChemDataLoader(BaseWrfChemDataLoader):
             raise TypeError(f"`coords' must be a tuple of floats or a dict, but is of type: {type(coords)}")
 
     def get_coordinates(self, as_arrays=False) -> Union[Tuple[np.ndarray, np.ndarray], dict]:
+        """
+        Returns physical coordinates
+        :param as_arrays: switch to return coords as array (True, lat/lon) or dictionary (False, default, lat/lon as keys)
+        :return:
+        :rtype:
+        """
         if as_arrays:
             return np.array(self.__coords['lat']), np.array(self.__coords['lon'])
         else:
             return self.__coords
 
     def _set_nearest_icoords(self, dim=None):
+        """
+        Set nearest logical coordinates (icoords). Coordinates of gridbox center.
+        :param dim:
+        :type dim:
+        :return:
+        :rtype:
+        """
         lat, lon = self.get_coordinates(as_arrays=True)
         with ProgressBar():
             logging.info("SingleGridColumnWrfChemDataLoader: compute nearest icoordinates")
             self._nearest_icoords = dask.compute(self.compute_nearest_icoordinates(lat, lon, dim))[0]
 
     def get_nearest_icoords(self, as_arrays=False):
+        """
+        Get nearest logical coordinates (icoords). Coordinates of gridbox center.
+        :param as_arrays:
+        :type as_arrays:
+        :return:
+        :rtype:
+        """
         if as_arrays:
             return (self._nearest_icoords[self.logical_y_coord_name].values,
                     self._nearest_icoords[self.logical_x_coord_name].values,
@@ -437,6 +568,11 @@ class SingleGridColumnWrfChemDataLoader(BaseWrfChemDataLoader):
             return {k: list(v.values) for k, v in self._nearest_icoords.items()}
 
     def _set_nearest_coords(self):
+        """
+        Set nearest physical coordinates (coords). Coordinates of gridbox center.
+        :return:
+        :rtype:
+        """
         icoords = self.get_nearest_icoords()
         ilat = convert2xrda(np.array(icoords[self.logical_y_coord_name]), use_1d_default=True)
         ilon = convert2xrda(np.array(icoords[self.logical_x_coord_name]), use_1d_default=True)
@@ -447,6 +583,14 @@ class SingleGridColumnWrfChemDataLoader(BaseWrfChemDataLoader):
         self._nearest_coords = dict(lat=lat, lon=lon)
 
     def get_nearest_coords(self, as_arrays=False):
+        """
+        Get nearest physical coordinates (coords). Coordinates of gridbox center.
+
+        :param as_arrays:
+        :type as_arrays:
+        :return:
+        :rtype:
+        """
         if as_arrays:
             return (self._nearest_coords['lat'].values,
                     self._nearest_coords['lon'].values,
@@ -466,10 +610,10 @@ class DataHandlerSingleGridColumn(DataHandlerSingleStation):
                  as_image_like_data_format=True,
                  **kwargs):
         self.external_coords_file = external_coords_file
-        self.var_logical_z_coord_selector = self._ret_z_coord_select_if_valid(var_logical_z_coord_selector,
-                                                                              as_input=True)
-        self.targetvar_logical_z_coord_selector = self._ret_z_coord_select_if_valid(targetvar_logical_z_coord_selector,
-                                                                                    as_input=False)
+        self.var_logical_z_coord_selector = self._return_z_coord_select_if_valid(var_logical_z_coord_selector,
+                                                                                 as_input=True)
+        self.targetvar_logical_z_coord_selector = self._return_z_coord_select_if_valid(targetvar_logical_z_coord_selector,
+                                                                                       as_input=False)
         self._logical_z_coord_name = None
         self._joint_z_coord_selector = self._extract_largest_coord_extractor(self.var_logical_z_coord_selector,
                                                                              self.targetvar_logical_z_coord_selector)
@@ -480,7 +624,7 @@ class DataHandlerSingleGridColumn(DataHandlerSingleStation):
         super().__init__(*args, **kwargs)
 
     @staticmethod
-    def _ret_z_coord_select_if_valid(z_coord: int_or_list_of_int, as_input: bool) -> int_or_list_of_int:
+    def _return_z_coord_select_if_valid(z_coord: int_or_list_of_int, as_input: bool) -> int_or_list_of_int:
         if isinstance(z_coord, int):
             return z_coord
         elif isinstance(z_coord, list):
@@ -496,6 +640,15 @@ class DataHandlerSingleGridColumn(DataHandlerSingleStation):
 
     @staticmethod
     def coord_str2coords(str_coords: str, sep='__', dec_marker='_') -> Tuple[float_np_xr, float_np_xr]:
+        """
+        Converts string of station names including lat lon information to numpy array
+        :param str_coords: Name of point of interest (pseudo station) e.g. 'coords__50_7536__7_0827'
+        :param sep: separator between 'name', 'lat', and 'lon' within the coordinate string (default: '__')
+        :param dec_marker: decimal marker for lat and lon values (default: '_')
+        :type dec_marker:
+        :return:
+        :rtype:
+        """
         if isinstance(str_coords, list) and len(str_coords) == 1:
             str_coords = str_coords[0]
         _, lat, lon = str_coords.split(sep=sep)
@@ -505,6 +658,7 @@ class DataHandlerSingleGridColumn(DataHandlerSingleStation):
         return np.array(float(lat)), np.array(float(lon))
 
     def setup_data_path(self, data_path: str, sampling: str):
+        # ToDo: clarify the purpose of this method; it currently returns data_path unchanged.
         return data_path
 
     @TimeTrackingWrapper
@@ -621,11 +775,6 @@ class DataHandlerSingleGridColumn(DataHandlerSingleStation):
         self.observation = self.modify_observation()
 
         self.remove_nan(self.time_dim)
-        # self.input_data = self.input_data.compute()
-        # self.label = self.label.compute()
-        # self.observation = self.observation.compute()
-        # self.target_data = self.target_data.compute()
-        # self._data.close()
 
     def modify_history(self):
         """
-- 
GitLab