From c9b110939379db860c8fb89e5bac1e875a6abb76 Mon Sep 17 00:00:00 2001
From: Felix Kleinert <f.kleinert@fz-juelich.de>
Date: Wed, 11 Aug 2021 14:25:18 +0200
Subject: [PATCH] include first impl of toarstatistics

---
 .gitmodules                                 |  3 +
 mlair/data_handler/data_handler_wrf_chem.py | 65 +++++++++++++++++++--
 mlair/external/toarstats                    |  1 +
 requirements.txt                            |  1 +
 requirements_gpu.txt                        |  1 +
 run_wrf_dh.py                               | 45 +++++++-------
 run_wrf_dh_sector.py                        |  2 +-
 7 files changed, 93 insertions(+), 25 deletions(-)
 create mode 100644 .gitmodules
 create mode 160000 mlair/external/toarstats

diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 00000000..45f0660c
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,3 @@
+[submodule "mlair/external/toarstats"]
+	path = mlair/external/toarstats
+	url = git@gitlab.jsc.fz-juelich.de:esde/toar-public/toarstats.git
diff --git a/mlair/data_handler/data_handler_wrf_chem.py b/mlair/data_handler/data_handler_wrf_chem.py
index 8188ddf3..4aedd310 100644
--- a/mlair/data_handler/data_handler_wrf_chem.py
+++ b/mlair/data_handler/data_handler_wrf_chem.py
@@ -8,6 +8,8 @@ import itertools
 import glob
 import matplotlib.pyplot as plt
 from dask.diagnostics import ProgressBar
+from tzwhere import tzwhere
+from toarstats import toarstats
 
 import dask
 import inspect
@@ -609,6 +611,7 @@ class DataHandlerSingleGridColumn(DataHandlerSingleStation):
                  rechunk_values=None,
                  date_format_of_nc_file=None,
                  as_image_like_data_format=True,
+                 time_zone=None,
                  **kwargs):
         self.external_coords_file = external_coords_file
         self.var_logical_z_coord_selector = self._return_z_coord_select_if_valid(var_logical_z_coord_selector,
@@ -622,6 +625,9 @@ class DataHandlerSingleGridColumn(DataHandlerSingleStation):
         self.rechunk_values = rechunk_values
         self.date_format_of_nc_file = date_format_of_nc_file
         self.as_image_like_data_format = as_image_like_data_format
+
+        self.time_zone = time_zone
+        self.tzwhere = tzwhere.tzwhere()
         super().__init__(*args, **kwargs)
 
     @staticmethod
@@ -722,18 +728,69 @@ class DataHandlerSingleGridColumn(DataHandlerSingleStation):
                 logging.info(f"start compute data for {self.station} in load_data")
                 data = dask.compute(data)[0]
 
+        # set time zone to attrs
+        data.attrs["time_zone"] = self.time_zone
+
         # expand dimesion for iterdim
         data = data.expand_dims({self.iter_dim: station}).to_array(self.target_dim)
         # transpose dataarray: set first three fixed and keep remaining as is
         data = data.transpose(self.iter_dim, self.time_dim, self.target_dim, ...)
-        data = self._slice_prep(data, start=start, end=end)
+
+
         # ToDo add metadata
-        _meta = {'station_lon': self.loader.get_coordinates()['lon'].tolist(),
-                 'station_lat': self.loader.get_coordinates()['lat'].tolist()}
-        meta = pd.DataFrame(_meta, index=station).T
+        _meta = {self.loader.physical_x_coord_name: self.loader.get_coordinates()['lon'].tolist(),
+                 self.loader.physical_y_coord_name: self.loader.get_coordinates()['lat'].tolist()}
+        meta = pd.DataFrame(_meta, index=station)
+
+        if isinstance(self.sampling, tuple) and len(self.sampling) == 2:
+            if self.var_logical_z_coord_selector != 0:
+                raise NotImplementedError(
+                    f"Method `apply_toarstats` is not implemented for var_logical_z_coord_selector != 0: "
+                    f"Is {self.var_logical_z_coord_selector}")
+            data = self.apply_toarstats(data, meta, self.sampling[1])
+        data = self._slice_prep(data, start=start, end=end)
 
         return data, meta
 
+    def apply_toarstats(self, data, meta, target_sampling="daily"):
+        local_time_zone = self.tzwhere.tzNameAt(latitude=meta[self.loader.physical_y_coord_name],
+                                                longitude=meta[self.loader.physical_x_coord_name])
+        hdata = data.squeeze().to_pandas()
+        hdata.index = self.set_time_zone(hdata.index)
+        hdata.index = hdata.index.tz_convert(local_time_zone)
+        hsampling_data = []
+        for i, var in enumerate(hdata.columns):
+            hdf = toarstats(target_sampling, self.statistics_per_var[var], hdata[var],
+                            (meta[self.loader.physical_y_coord_name], meta[self.loader.physical_x_coord_name]))
+            # if i == 0:
+            hsampling_data.append(xr.DataArray(hdf, coords=[hdf.index, [var]], dims=[self.loader.physical_t_coord_name, self.target_dim]))
+            # else:
+            #     df[var] = hdf
+        sampling_data = xr.concat(hsampling_data, dim=self.target_dim)
+        sampling_data = sampling_data.broadcast_like(data, exclude=self.loader.physical_t_coord_name).dropna(
+            self.loader.physical_t_coord_name)
+        sampling_data.attrs = data.attrs
+        missing_squeezed_coords = data.coords._names - sampling_data.coords._names
+        for coord in missing_squeezed_coords:
+            sampling_data.coords[coord] = data.coords[coord]
+
+        return self. sampling_data
+
+    def set_time_zone(self, time_index):
+        """
+        Sets time zone information on a given index
+
+        :param time_index:
+        :type time_index:
+        :return:
+        :rtype:
+        """
+
+        dti = pd.to_datetime(time_index)
+        dti = dti.tz_localize(self.time_zone)
+        logging.info(f"Set external time zone for {self.station} to: {self.time_zone}")
+        return dti
+
     @staticmethod
     def _extract_largest_coord_extractor(var_extarctor, target_extractor) -> Union[List, None]:
         if var_extarctor is not None and target_extractor is not None:
diff --git a/mlair/external/toarstats b/mlair/external/toarstats
new file mode 160000
index 00000000..a3c20206
--- /dev/null
+++ b/mlair/external/toarstats
@@ -0,0 +1 @@
+Subproject commit a3c2020686d300db6ca5a1d0b81128f4740e2b16
diff --git a/requirements.txt b/requirements.txt
index dba565fb..a2ccc2b0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -64,6 +64,7 @@ termcolor==1.1.0
 toml==0.10.2
 toolz==0.11.1
 typing-extensions==3.7.4.3
+tzwhere==3.0.3
 urllib3==1.26.3
 Werkzeug==1.0.1
 wget==3.2
diff --git a/requirements_gpu.txt b/requirements_gpu.txt
index f170e1b7..be648ff8 100644
--- a/requirements_gpu.txt
+++ b/requirements_gpu.txt
@@ -64,6 +64,7 @@ termcolor==1.1.0
 toml==0.10.2
 toolz==0.11.1
 typing-extensions==3.7.4.3
+tzwhere==3.0.3
 urllib3==1.26.3
 Werkzeug==1.0.1
 wget==3.2
diff --git a/run_wrf_dh.py b/run_wrf_dh.py
index b9d4cd02..a406f193 100644
--- a/run_wrf_dh.py
+++ b/run_wrf_dh.py
@@ -8,7 +8,7 @@ from mlair.workflows import DefaultWorkflow
 from mlair.helpers import remove_items
 from mlair.configuration.defaults import DEFAULT_PLOT_LIST
 
-from mlair.model_modules.model_class import MyPaperModel, MyLSTMModel, MyCNNModel
+from mlair.model_modules.model_class import MyLSTMModel, MyCNNModel
 import os
 
 
@@ -33,6 +33,7 @@ def main(parser_args):
         evaluate_bootstraps=True,  number_of_bootstraps=30, create_new_bootstraps=True,  #
         plot_list=plots,
         model_name_for_plots='SecModel',
+        time_zone="UTC",
 #         competitors=["test_model", "test_model2"],
 #         competitor_path=os.path.join(os.getcwd(), "data", "comp_test"),
         competitors=["baseline", "sector_baseline"],
@@ -40,9 +41,9 @@ def main(parser_args):
         train_min_length=1, val_min_length=1, test_min_length=1,
         # data_handler=DataHandlerSingleStation,
         # data_handler=DataHandlerSingleGridColumn,
-        upsampeling=True,
+        upsampeling=False,
         epochs=20,
-        window_lead_time=2,
+        window_lead_time=4,
         window_history_size=6,
         # ('Germany', (5.98865807458, 47.3024876979, 15.0169958839, 54.983104153))
         stations=["coords__48_8479__10_0963", "coords__51_8376__14_1417",
@@ -65,9 +66,9 @@ def main(parser_args):
         # data_path='/home/felix/Data/WRF-Chem/upload_aura_2021-02-24/2009/',
         # data_path='/home/felix/Data/WRF-Chem/test_cut_nc/',
         # data_path='/home/felix/Data/WRF-Chem/test_cut_nc_joint',
-        data_path="/home/felix/Data/WRF-Chem/test_cut_nc_joint/short_test",
+        # data_path="/home/felix/Data/WRF-Chem/test_cut_nc_joint/short_test",
         # data_path = "/p/scratch/deepacf/kleinert1/IASS_proc_monthyl",
-        # data_path="/media/felix/INTENSO/WRF_CHEM/JFM_2009",
+        data_path="/media/felix/INTENSO/WRF_CHEM/JFM_2009",
 
         # external data coords
         external_coords_file='/home/felix/Data/WRF-Chem/test_cut_nc/coords.nc',
@@ -93,8 +94,11 @@ def main(parser_args):
         },
         variables=['T2', 'o3', 'wdir10ll', 'wspd10ll', 'no', 'no2', 'co', 'PSFC', 'PBLH', 'CLDFRA'],
         target_var='o3',
-        statistics_per_var={'T2': None, 'o3': None, 'wdir10ll': None, 'wspd10ll': None,
-                            'no': None, 'no2': None, 'co': None, 'PSFC': None, 'PBLH': None, 'CLDFRA': None, },
+        # statistics_per_var={'T2': None, 'o3': None, 'wdir10ll': None, 'wspd10ll': None,
+        #                     'no': None, 'no2': None, 'co': None, 'PSFC': None, 'PBLH': None, 'CLDFRA': None, },
+        statistics_per_var={'T2': "average_values", 'o3': "dma8eu", 'wdir10ll': "average_values",
+                            'wspd10ll': "average_values", 'no': "dma8eu", 'no2': "dma8eu", 'co': "dma8eu",
+                            'PSFC': "average_values", 'PBLH': "average_values", 'CLDFRA': "average_values", },
         # separate_vars=["o3", "o3Sect", "o3SectLeft", "o3SectRight"],
         separate_vars=["o3", "o3Sect"],
         # variables=['T2', 'Q2', 'PBLH', 'U10ll', 'V10ll', 'wdir10ll', 'wspd10ll'],
@@ -108,35 +112,36 @@ def main(parser_args):
         radius=100,  # km
 
         start='2009-01-01',
-        end='2009-01-04',
+        # end='2009-01-04',
         #end='2009-01-31',
-        # end='2009-03-31',
+        end='2009-03-31',
         
         train_start='2009-01-01',
-        train_end='2009-01-02',
+        # train_end='2009-01-02',
         # train_start='2009-01-01',
         #train_end='2009-01-15',
-        # train_end='2009-02-28',
+        train_end='2009-02-28',
         
-        val_start='2009-01-02',
-        val_end='2009-01-03',
+        # val_start='2009-01-02',
+        # val_end='2009-01-03',
         ###################################
         #val_start='2009-01-15',
         #val_end='2009-01-22',
         ###################################
-        # val_start='2009-03-01',
-        # val_end='2009-03-14',
+        val_start='2009-03-01',
+        val_end='2009-03-14',
 
-        test_start='2009-01-03',
-        test_end='2009-01-04',
+        # test_start='2009-01-03',
+        # test_end='2009-01-04',
         ###################################
         #test_start='2009-01-22',
         #test_end='2009-01-31',
         ###################################
-        # test_start='2009-03-15',
-        # test_end='2009-03-31',
+        test_start='2009-03-15',
+        test_end='2009-03-31',
         
-        sampling='hourly',
+        # sampling='hourly',
+        sampling=("hourly", "daily"),
 
         interpolation_limit=0,
         # as_image_like_data_format=False,
diff --git a/run_wrf_dh_sector.py b/run_wrf_dh_sector.py
index d77bfd74..688631ff 100644
--- a/run_wrf_dh_sector.py
+++ b/run_wrf_dh_sector.py
@@ -7,7 +7,7 @@ from mlair.data_handler.data_handler_wrf_chem import DataHandlerWRF, DataHandler
 from mlair.workflows import DefaultWorkflow
 from mlair.helpers import remove_items
 from mlair.configuration.defaults import DEFAULT_PLOT_LIST
-from mlair.model_modules.model_class import MyPaperModel, MyLSTMModel, MyCNNModel, MyCNNModelSect
+from mlair.model_modules.model_class import MyLSTMModel, MyCNNModel, MyCNNModelSect
 
 import os
 
-- 
GitLab