From c043134c15b3fc7d8db04b07946bfd0e646dccaa Mon Sep 17 00:00:00 2001
From: Felix Kleinert <f.kleinert@fz-juelich.de>
Date: Tue, 19 Jul 2022 09:13:38 +0200
Subject: [PATCH] add PlotStationsPerGridBox

---
 mlair/data_handler/data_handler_wrf_chem.py |  6 ++
 mlair/plotting/data_insight_plotting.py     | 73 ++++++++++++++++++++-
 mlair/run_modules/post_processing.py        | 12 +++-
 3 files changed, 89 insertions(+), 2 deletions(-)

diff --git a/mlair/data_handler/data_handler_wrf_chem.py b/mlair/data_handler/data_handler_wrf_chem.py
index ae1fc9f9..3d336e6c 100644
--- a/mlair/data_handler/data_handler_wrf_chem.py
+++ b/mlair/data_handler/data_handler_wrf_chem.py
@@ -898,6 +898,9 @@ class DataHandlerSingleGridColumn(DataHandlerSingleStation):
         self.rechunk_values = rechunk_values
         self.date_format_of_nc_file = date_format_of_nc_file
         self.as_image_like_data_format = as_image_like_data_format
+        self.coords = None
+        self.nearest_coords = None
+        self.nearest_icoords = None
 
         self.time_zone = time_zone
         self.target_time_type = target_time_type
@@ -1003,6 +1006,9 @@ class DataHandlerSingleGridColumn(DataHandlerSingleStation):
             meta_col_names = ['station_name', 'station_lon', 'station_lat', 'station_alt']
 
         with self.loader as loader:
+            self.coords = loader.get_coordinates()
+            self.nearest_coords = loader.get_nearest_coords()
+            self.nearest_icoords = loader.get_nearest_icoords()
             if self._logical_z_coord_name is None:
                 self._logical_z_coord_name = loader.logical_z_coord_name
             # # select defined variables at grid box or grid coloumn based on nearest icoords
diff --git a/mlair/plotting/data_insight_plotting.py b/mlair/plotting/data_insight_plotting.py
index db2b3340..31056c52 100644
--- a/mlair/plotting/data_insight_plotting.py
+++ b/mlair/plotting/data_insight_plotting.py
@@ -17,9 +17,10 @@ import matplotlib
 # matplotlib.use("Agg")
 from matplotlib import lines as mlines, pyplot as plt, patches as mpatches, dates as mdates
 from astropy.timeseries import LombScargle
+from collections import Counter
 
 from mlair.data_handler import DataCollection
-from mlair.helpers import TimeTrackingWrapper, to_list, remove_items
+from mlair.helpers import TimeTrackingWrapper, to_list, remove_items, tables
 from mlair.plotting.abstract_plot_class import AbstractPlotClass
 
 
@@ -1271,3 +1272,73 @@ class PlotFirFilter(AbstractPlotClass):  # pragma: no cover
         file = os.path.join(self.plot_folder, "plot_data.pickle")
         with open(file, "wb") as f:
             dill.dump(data, f)
+
+
+class PlotStationsPerGridBox(AbstractPlotClass):
+    def __init__(self, plot_folder: str, data_set, report_path):
+        file_name = "Stations_per_grid_box"
+        super().__init__(plot_folder, f"{file_name}.pdf")
+        self.data_set = data_set
+        self.report_path = report_path
+        self.plot_data = self._count_boxes()
+        self.store_report(file_name)
+        self.base_plotname = self.plot_name
+
+        self._plot(plot_type="bar")
+        self._save(bbox_inches="tight")
+
+        self._plot(plot_type="hist")
+        self._save(bbox_inches="tight")
+
+    def _count_boxes(self):
+        nearest_icoords = []
+
+        for data in self.data_set:
+            data_collection_set = data._collection[0]
+            # n_icoo = data_collection_set.nearest_icoords
+            # n_coo = data_collection_set.nearest_coords
+            # e_coo = data_collection_set.coords
+            nearest_icoords.append('_'.join([str(v[0]) for v in  data_collection_set.nearest_icoords.values()]))
+
+        icoords_count = Counter(nearest_icoords)
+        df = pd.DataFrame(dict(icoords_count), index=["Number of Stations"]).T
+        df.index.name = "Grid box (id)"
+        df = df.sort_values(by=[df.columns[0], df.index.name])
+        return df
+
+    def _plot(self, plot_type):
+        if plot_type == "bar":
+            self._plot_bar()
+        elif plot_type == "hist":
+            self._plot_hist()
+        else:
+            raise ValueError(f"`plot_type' must be 'bar' or hist, but is {plot_type}" )
+        self.plot_name = f"{self.base_plotname}_{plot_type}"
+
+    def _plot_bar(self):
+        fig, ax = plt.subplots()
+        df = self.plot_data
+        df.plot(kind="bar", ax=ax, legend=False)
+        ax.set_xlabel(df.index.name)
+        ax.set_ylabel("Number of Pseudo-Stations in Grid box")
+
+    def _plot_hist(self):
+        fig, ax = plt.subplots()
+        df = self.plot_data
+        labels, counts = np.unique(df, return_counts=True)
+
+        ax.bar(labels, counts, align='center')
+        plt.gca().set_xticks(labels)
+
+        ax.set_ylabel("Number of Pseudo-Stations")
+        ax.set_xlabel("Number of Pseudo-Stations in Grid box")
+
+    def store_report(self, file_name):
+        self.plot_data.to_csv(os.path.join(self.report_path, f"{file_name}.csv"), sep=";")
+        df_descr = self.plot_data.describe().T
+        column_format = tables.create_column_format_for_tex(df_descr)
+        tables.save_to_tex(self.report_path, f"{file_name}.tex", column_format=column_format, df=df_descr)
+        tables.save_to_md(self.report_path, f"{file_name}.md", df=df_descr)
+
+
+
diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py
index c38d92f6..9ddc39e9 100644
--- a/mlair/run_modules/post_processing.py
+++ b/mlair/run_modules/post_processing.py
@@ -26,7 +26,7 @@ from mlair.plotting.postprocessing_plotting import PlotMonthlySummary, PlotClima
     PlotCompetitiveSkillScore, PlotTimeSeries, PlotFeatureImportanceSkillScore, PlotConditionalQuantiles, \
     PlotSeparationOfScales, PlotSampleUncertaintyFromBootstrap, PlotSectorialSkillScore, PlotScoresOnMap
 from mlair.plotting.data_insight_plotting import PlotStationMap, PlotAvailability, PlotAvailabilityHistogram, \
-    PlotPeriodogram, PlotDataHistogram
+    PlotPeriodogram, PlotDataHistogram, PlotStationsPerGridBox
 from mlair.run_modules.run_environment import RunEnvironment
 
 
@@ -535,6 +535,16 @@ class PostProcessing(RunEnvironment):
         iter_dim = self.data_store.get("iter_dim")
         separate_vars = self.data_store.get_default("separate_vars", None)
 
+
+        try:
+            if "PlotStationsPerGridBox" in plot_list:
+                report_path = os.path.join(self.data_store.get("experiment_path"), "latex_report")
+                PlotStationsPerGridBox(plot_folder=self.plot_path, data_set=self.train_data, report_path=report_path)
+        except Exception as e:
+            logging.error(f"Could not create plot PlotStationsPerGridBox due to the following error:"
+                          f"\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}")
+
+
         try:
             if ("filter" in self.test_data[0].get_X(as_numpy=False)[0].coords) and (
                     "PlotSeparationOfScales" in plot_list):
-- 
GitLab