From 1863bc190ae196bfb2bae77414d7682bc4c0f4fa Mon Sep 17 00:00:00 2001
From: leufen1 <l.leufen@fz-juelich.de>
Date: Tue, 13 Jun 2023 15:07:46 +0200
Subject: [PATCH] downloaded data have an additional window dimension if this
 dim is greater than 1

---
 .../data_handler_single_station.py            |  5 +-
 mlair/helpers/data_sources/__init__.py        |  2 +-
 mlair/helpers/data_sources/data_loader.py     | 56 +++++++++++--------
 mlair/helpers/data_sources/era5.py            | 43 ++++++++------
 mlair/helpers/data_sources/ifs.py             | 13 +++--
 mlair/helpers/helpers.py                      |  3 +-
 mlair/run_modules/pre_processing.py           |  3 +-
 7 files changed, 76 insertions(+), 49 deletions(-)

diff --git a/mlair/data_handler/data_handler_single_station.py b/mlair/data_handler/data_handler_single_station.py
index 0be52e93..76f73bf4 100644
--- a/mlair/data_handler/data_handler_single_station.py
+++ b/mlair/data_handler/data_handler_single_station.py
@@ -337,8 +337,9 @@ class DataHandlerSingleStation(AbstractDataHandler):
             if os.path.exists(meta_file):
                 os.remove(meta_file)
             data, meta = data_sources.download_data(file_name, meta_file, station, statistics_per_var, sampling,
-                                            store_data_locally=store_data_locally, data_origin=data_origin,
-                                            time_dim=self.time_dim, target_dim=self.target_dim, iter_dim=self.iter_dim)
+                                                    store_data_locally=store_data_locally, data_origin=data_origin,
+                                                    time_dim=self.time_dim, target_dim=self.target_dim,
+                                                    iter_dim=self.iter_dim, window_dim=self.window_dim)
             logging.debug(f"loaded new data")
         else:
             try:
diff --git a/mlair/helpers/data_sources/__init__.py b/mlair/helpers/data_sources/__init__.py
index 21caa40e..34c70c5e 100644
--- a/mlair/helpers/data_sources/__init__.py
+++ b/mlair/helpers/data_sources/__init__.py
@@ -7,5 +7,5 @@ The module data_sources collects different data sources, namely ERA5, TOAR-Data
 __author__ = "Lukas Leufen"
 __date__ = "2023-06-01"
 
-from . import era5, join, toar_data, toar_data_v2, data_loader
+from . import era5, join, toar_data, toar_data_v2, data_loader, ifs
 from .data_loader import download_data
diff --git a/mlair/helpers/data_sources/data_loader.py b/mlair/helpers/data_sources/data_loader.py
index 8027e46d..c30568b8 100644
--- a/mlair/helpers/data_sources/data_loader.py
+++ b/mlair/helpers/data_sources/data_loader.py
@@ -15,11 +15,12 @@ import xarray as xr
 DEFAULT_TIME_DIM = "datetime"
 DEFAULT_TARGET_DIM = "variables"
 DEFAULT_ITER_DIM = "Stations"
+DEFAULT_WINDOW_DIM = "window"
 
 
 def download_data(file_name: str, meta_file: str, station, statistics_per_var, sampling,
                   store_data_locally=True, data_origin: Dict = None, time_dim=DEFAULT_TIME_DIM,
-                  target_dim=DEFAULT_TARGET_DIM, iter_dim=DEFAULT_ITER_DIM) -> [xr.DataArray, pd.DataFrame]:
+                  target_dim=DEFAULT_TARGET_DIM, iter_dim=DEFAULT_ITER_DIM, window_dim=DEFAULT_WINDOW_DIM) -> [xr.DataArray, pd.DataFrame]:
     """
     Download data from TOAR database using the JOIN interface or load local era5 data.
 
@@ -31,45 +32,54 @@ def download_data(file_name: str, meta_file: str, station, statistics_per_var, s
 
     :return: downloaded data and its meta data
     """
-    df_all = {}
-    df_era5_local, df_toar = None, None
-    meta_era5_local, meta_toar = None, None
+    df_era5_local, df_toar, df_ifs_local = None, None, None
+    meta_era5_local, meta_toar, meta_ifs_local = None, None, None
     if data_origin is not None:
         era5_local_origin = filter_dict_by_value(data_origin, "era5_local", True)
         era5_local_stats = select_from_dict(statistics_per_var, era5_local_origin.keys())
-        toar_origin = filter_dict_by_value(data_origin, "era5_local", False)
-        toar_stats = select_from_dict(statistics_per_var, era5_local_origin.keys(), filter_cond=False)
-        assert len(era5_local_origin) + len(toar_origin) == len(data_origin)
-        assert len(era5_local_stats) + len(toar_stats) == len(statistics_per_var)
+        ifs_local_origin = filter_dict_by_value(data_origin, "ifs", True)
+        ifs_local_stats = select_from_dict(statistics_per_var, ifs_local_origin.keys())
+        toar_origin = filter_dict_by_value(data_origin, ["era5_local", "ifs"], False)
+        toar_stats = select_from_dict(statistics_per_var, toar_origin.keys())
+        assert len(era5_local_origin) + len(toar_origin) + len(ifs_local_origin) == len(data_origin)
+        assert len(era5_local_stats) + len(toar_stats) + len(ifs_local_stats) == len(statistics_per_var)
     else:
-        era5_local_origin, toar_origin = None, None
-        era5_local_stats, toar_stats = statistics_per_var, statistics_per_var
+        era5_local_origin, toar_origin, ifs_local_origin = None, None, None
+        era5_local_stats, toar_stats, ifs_local_stats = statistics_per_var, statistics_per_var, statistics_per_var
 
     # load data
     if era5_local_origin is not None and len(era5_local_stats) > 0:
         # load era5 data
         df_era5_local, meta_era5_local = data_sources.era5.load_era5(
-            station_name=station, stat_var=era5_local_stats, sampling=sampling, data_origin=era5_local_origin)
-    if toar_origin is None or len(toar_stats) > 0:
+            station_name=station, stat_var=era5_local_stats, sampling=sampling, data_origin=era5_local_origin,
+            window_dim=window_dim, time_dim=time_dim)
+    if ifs_local_origin is not None and len(ifs_local_stats) > 0:
+        # load ifs data
+        df_ifs_local, meta_ifs_local = data_sources.ifs.load_ifs(
+            station_name=station, stat_var=ifs_local_stats, sampling=sampling, data_origin=ifs_local_origin,
+            lead_time_dim=window_dim, initial_time_dim=time_dim)
+    if toar_origin is None or len(toar_stats) > 0:  # TODO return toar data as xarray
         # load combined data from toar-data (v2 & v1)
         df_toar, meta_toar = data_sources.toar_data.download_toar(station=station, toar_stats=toar_stats,
                                                                   sampling=sampling, data_origin=toar_origin)
 
-    if df_era5_local is None and df_toar is None:
-        raise EmptyQueryResult(f"No data available for era5_local and toar-data")
-
-    df = pd.concat([df_era5_local, df_toar], axis=1, sort=True)
-    if meta_era5_local is not None and meta_toar is not None:
-        meta = meta_era5_local.combine_first(meta_toar)
+    valid_df = [e for e in [df_era5_local, df_toar, df_ifs_local] if e is not None]
+    if len(valid_df) == 0:
+        raise EmptyQueryResult("No data available for era5_local, toar-data and ifs_local")
+    df = xr.concat(valid_df, dim=time_dim)
+    valid_meta = [e for e in [meta_era5_local, meta_toar, meta_ifs_local] if e is not None]
+    if len(valid_meta) > 0:
+        meta = valid_meta[0]
+        for e in valid_meta[1:]:
+            meta = meta.combine_first(e)
     else:
-        meta = meta_era5_local if meta_era5_local is not None else meta_toar
+        meta = None
     meta.loc["data_origin"] = str(data_origin)
     meta.loc["statistics_per_var"] = str(statistics_per_var)
 
-    df_all[station[0]] = df
-    # convert df_all to xarray
-    xarr = {k: xr.DataArray(v, dims=[time_dim, target_dim]) for k, v in df_all.items()}
-    xarr = xr.Dataset(xarr).to_array(dim=iter_dim)
+    xarr = df.expand_dims({iter_dim: station})
+    if len(xarr.coords[window_dim]) <= 1:  # keep window dim only if there is more than a single entry
+        xarr = xarr.squeeze(window_dim, drop=True)
     if store_data_locally is True:
         # save locally as nc/csv file
         xarr.to_netcdf(path=file_name)
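
For reference, a minimal sketch (not part of the patch) of the window handling
introduced in download_data: the window dimension survives only if it holds
more than one entry. Dimension names are the data_loader defaults; the station
id and array contents are placeholders.

    import numpy as np
    import xarray as xr

    # toy loader output shaped (window, datetime, variables)
    data = xr.DataArray(np.zeros((1, 4, 2)),
                        dims=("window", "datetime", "variables"),
                        coords={"window": [0]})
    data = data.expand_dims({"Stations": ["XX001"]})  # placeholder station id
    # same rule as in download_data: keep window only if there is more than one entry
    if len(data.coords["window"]) <= 1:
        data = data.squeeze("window", drop=True)
    assert "window" not in data.dims
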
diff --git a/mlair/helpers/data_sources/era5.py b/mlair/helpers/data_sources/era5.py
index 3e81a460..66117106 100644
--- a/mlair/helpers/data_sources/era5.py
+++ b/mlair/helpers/data_sources/era5.py
@@ -4,6 +4,7 @@ __date__ = "2022-06-09"
 
 import logging
 import os
+from functools import partial
 
 import pandas as pd
 import xarray as xr
@@ -16,7 +17,7 @@ from mlair.helpers.data_sources.data_loader import EmptyQueryResult
 from mlair.helpers.meteo import relative_humidity_from_dewpoint
 
 
-def load_era5(station_name, stat_var, sampling, data_origin):
+def load_era5(station_name, stat_var, sampling, data_origin, time_dim, window_dim):
 
     # make sure station_name parameter is a list
     station_name = helpers.to_list(station_name)
@@ -37,43 +38,53 @@ def load_era5(station_name, stat_var, sampling, data_origin):
     # sel data for station using sel method nearest
     logging.info(f"load data for {station_meta['codes'][0]} from ERA5")
     try:
-        with xr.open_mfdataset(os.path.join(data_path, file_names)) as data:
-            lon, lat = station_meta["coordinates"]["lng"],  station_meta["coordinates"]["lat"]
-            station_dask = data.sel(lon=lon, lat=lat, method="nearest", drop=True)
-            station_data = station_dask.to_array().T.compute()
+        lon, lat = station_meta["coordinates"]["lng"], station_meta["coordinates"]["lat"]
+        file_names = os.path.join(data_path, file_names)
+        with xr.open_mfdataset(file_names, preprocess=partial(preprocess_era5_single_file, lon, lat)) as data:
+            station_data = data.to_array().T.compute()
     except OSError as e:
         logging.info(f"Cannot load era5 data from path {data_path} and filenames {file_names} due to: {e}")
         return None, None
 
-    # transform data and meta to pandas
-    station_data = station_data.to_pandas()
     if "relhum" in stat_var:
-        station_data["RHw"] = relative_humidity_from_dewpoint(station_data["D2M"], station_data["T2M"])
-    station_data.columns = _rename_era5_variables(station_data.columns)
+        relhum = relative_humidity_from_dewpoint(station_data.sel(variable="D2M"), station_data.sel(variable="T2M"))
+        station_data = xr.concat([station_data, relhum.expand_dims({"variable": ["RHw"]})], dim="variable")
+    station_data.coords["variable"] = _rename_era5_variables(station_data.coords["variable"].values)
 
     # check if all requested variables are available
-    if set(stat_var).issubset(station_data.columns) is False:
-        missing_variables = set(stat_var).difference(station_data.columns)
+    if set(stat_var).issubset(station_data.coords["variable"].values) is False:
+        missing_variables = set(stat_var).difference(station_data.coords["variable"].values)
         origin = helpers.select_from_dict(data_origin, missing_variables)
         options = f"station={station_name}, origin={origin}"
         raise EmptyQueryResult(f"No data found for variables {missing_variables} and options {options} in JOIN.")
     else:
-        station_data = station_data[stat_var]
+        station_data = station_data.sel(variable=list(stat_var.keys()))
 
     # convert to local timezone
-    station_data = correct_timezone(station_data, station_meta, sampling)
+    station_data.coords["time"] = correct_timezone(station_data.to_pandas(), station_meta, sampling).index
+    station_data = station_data.rename({"time": time_dim})
 
-    variable_meta = _emulate_meta_data(station_data)
+    # expand window_dim
+    station_data = station_data.expand_dims({window_dim: [0]})
+
+    # create meta data
+    variable_meta = _emulate_meta_data(station_data.coords["variable"].values)
     meta = combine_meta_data(station_meta, variable_meta)
     meta = pd.DataFrame.from_dict(meta, orient='index')
     meta.columns = station_name
     return station_data, meta
 
 
-def _emulate_meta_data(station_data):
+def preprocess_era5_single_file(lon, lat, ds):
+    """Select lon and lat from data file and transform valid time into lead time."""
+    ds = ds.sel(lon=lon, lat=lat, method="nearest", drop=True)
+    return ds
+
+
+def _emulate_meta_data(variables):
     general_meta = {"sampling_frequency": "hourly", "data_origin": "model", "data_origin_type": "model"}
     roles_meta = {"roles": [{"contact": {"organisation": {"name": "ERA5", "longname": "ECMWF"}}}]}
-    variable_meta = {var: {"variable": {"name": var}, **roles_meta, ** general_meta} for var in station_data.columns}
+    variable_meta = {var: {"variable": {"name": var}, **roles_meta, **general_meta} for var in variables}
     return variable_meta
 
 
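For reference, a sketch (not part of the patch) of the per-file preprocessing
pattern now used in load_era5: xarray applies the preprocess callable to each
file before concatenation, so only the nearest grid point is read instead of
the full fields. The path and coordinates below are placeholders.

    from functools import partial

    import xarray as xr

    def _nearest_point(lon, lat, ds):
        # reduce a single file to the grid point closest to the station
        return ds.sel(lon=lon, lat=lat, method="nearest", drop=True)

    # open_mfdataset passes each dataset as the last positional argument,
    # which partial leaves open after binding lon and lat
    data = xr.open_mfdataset("/path/to/era5_*.nc",  # placeholder path
                             preprocess=partial(_nearest_point, 6.41, 50.91))
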
diff --git a/mlair/helpers/data_sources/ifs.py b/mlair/helpers/data_sources/ifs.py
index e9e75bd4..27faab3e 100644
--- a/mlair/helpers/data_sources/ifs.py
+++ b/mlair/helpers/data_sources/ifs.py
@@ -20,7 +20,7 @@ from mlair.helpers.data_sources.data_loader import EmptyQueryResult
 from mlair.helpers.meteo import relative_humidity_from_dewpoint
 
 
-def load_ifs(station_name, stat_var, sampling, data_origin):
+def load_ifs(station_name, stat_var, sampling, data_origin, lead_time_dim, initial_time_dim):
 
     # make sure station_name parameter is a list
     station_name = helpers.to_list(station_name)
@@ -65,10 +65,13 @@ def load_ifs(station_name, stat_var, sampling, data_origin):
         station_data = station_data.sel(variable=list(stat_var.keys()))
 
     # convert to local timezone
-    station_data.coords["initial_time"] = correct_timezone(station_data.sel(lead_time=1).to_pandas(), station_meta,
+    station_data.coords["initial_time"] = correct_timezone(station_data.sel(lead_time=0).to_pandas(), station_meta,
                                                            sampling).index
 
-    variable_meta = _emulate_meta_data(station_data)
+    # rename lead time and initial time to MLAir's internal dimension names
+    station_data = station_data.rename({"lead_time": lead_time_dim, "initial_time": initial_time_dim})
+
+    variable_meta = _emulate_meta_data(station_data.coords["variable"].values)
     meta = combine_meta_data(station_meta, variable_meta)
     meta = pd.DataFrame.from_dict(meta, orient='index')
     meta.columns = station_name
@@ -101,10 +104,10 @@ def expand_dims_initial_time(ds):
     return ds
 
 
-def _emulate_meta_data(station_data):
+def _emulate_meta_data(variables):
     general_meta = {"sampling_frequency": "hourly", "data_origin": "model", "data_origin_type": "model"}
     roles_meta = {"roles": [{"contact": {"organisation": {"name": "IFS", "longname": "ECMWF"}}}]}
-    variable_meta = {var: {"variable": {"name": var}, **roles_meta, ** general_meta} for var in station_data.coords["variable"].values}
+    variable_meta = {var: {"variable": {"name": var}, **roles_meta, **general_meta} for var in variables}
     return variable_meta
 
 
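For reference, a sketch (not part of the patch) of the dimension mapping in
load_ifs: the IFS-native names are renamed to MLAir's configured dimensions
(the data_loader defaults are shown); the array shape is a placeholder.

    import numpy as np
    import xarray as xr

    ifs_data = xr.DataArray(np.zeros((3, 5)), dims=("lead_time", "initial_time"))
    # lead_time_dim="window", initial_time_dim="datetime" with the defaults
    ifs_data = ifs_data.rename({"lead_time": "window", "initial_time": "datetime"})
    assert ifs_data.dims == ("window", "datetime")
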
diff --git a/mlair/helpers/helpers.py b/mlair/helpers/helpers.py
index 0b97f826..4ebea8a0 100644
--- a/mlair/helpers/helpers.py
+++ b/mlair/helpers/helpers.py
@@ -283,7 +283,8 @@ def filter_dict_by_value(dictionary: dict, filter_val: Any, filter_cond: bool) -
         do not match the criteria (if `False`)
     :returns: a filtered dict with either matching or non-matching elements depending on the `filter_cond`
     """
-    return dict(filter(lambda x: (x[1] == filter_val) is filter_cond, dictionary.items()))
+    filter_val = to_list(filter_val)
+    return dict(filter(lambda x: (x[1] in filter_val) is filter_cond, dictionary.items()))
 
 
 def str2bool(v):
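
For reference, a usage sketch (not part of the patch) of the extended
filter_dict_by_value, which now accepts a single value or a list of values.
The data_origin mapping below is a toy example, and the import path is
inferred from the file location.

    from mlair.helpers.helpers import filter_dict_by_value

    data_origin = {"o3": "toar", "t2m": "era5_local", "u10": "ifs"}
    # keep entries whose value matches any listed value ...
    filter_dict_by_value(data_origin, ["era5_local", "ifs"], True)
    # -> {"t2m": "era5_local", "u10": "ifs"}
    # ... or keep only the non-matching remainder
    filter_dict_by_value(data_origin, ["era5_local", "ifs"], False)
    # -> {"o3": "toar"}
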
diff --git a/mlair/run_modules/pre_processing.py b/mlair/run_modules/pre_processing.py
index 5710b633..d56d064a 100644
--- a/mlair/run_modules/pre_processing.py
+++ b/mlair/run_modules/pre_processing.py
@@ -295,11 +295,12 @@ class PreProcessing(RunEnvironment):
         else:  # serial solution
             logging.info("use serial validate station approach")
             kwargs.update({"return_strategy": "result"})
-            for station in set_stations:
+            for i, station in enumerate(set_stations):
                 dh, s = f_proc(data_handler, station, set_name, store_processed_data, **kwargs)
                 if dh is not None:
                     collection.add(dh)
                     valid_stations.append(s)
+                logging.info(f"...finished: {s} ({int((i + 1.) / len(set_stations) * 100)}%)")
 
         logging.info(f"run for {t_outer} to check {len(set_stations)} station(s). Found {len(collection)}/"
                      f"{len(set_stations)} valid stations ({set_name}).")
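
For reference, a sketch (not part of the patch) of the progress figure added
to the serial loop: the percentage is rounded down by int(), so intermediate
steps never show 100% early. Station ids below are placeholders.

    set_stations = ["s1", "s2", "s3"]  # placeholder station ids
    for i, station in enumerate(set_stations):
        print(f"...finished: {station} ({int((i + 1.) / len(set_stations) * 100)}%)")
    # prints 33%, 66%, 100%
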
-- 
GitLab