From c043134c15b3fc7d8db04b07946bfd0e646dccaa Mon Sep 17 00:00:00 2001 From: Felix Kleinert <f.kleinert@fz-juelich.de> Date: Tue, 19 Jul 2022 09:13:38 +0200 Subject: [PATCH] add PlotStationsPerGridBox --- mlair/data_handler/data_handler_wrf_chem.py | 6 ++ mlair/plotting/data_insight_plotting.py | 73 ++++++++++++++++++++- mlair/run_modules/post_processing.py | 12 +++- 3 files changed, 89 insertions(+), 2 deletions(-) diff --git a/mlair/data_handler/data_handler_wrf_chem.py b/mlair/data_handler/data_handler_wrf_chem.py index ae1fc9f9..3d336e6c 100644 --- a/mlair/data_handler/data_handler_wrf_chem.py +++ b/mlair/data_handler/data_handler_wrf_chem.py @@ -898,6 +898,9 @@ class DataHandlerSingleGridColumn(DataHandlerSingleStation): self.rechunk_values = rechunk_values self.date_format_of_nc_file = date_format_of_nc_file self.as_image_like_data_format = as_image_like_data_format + self.coords = None + self.nearest_coords = None + self.nearest_icoords = None self.time_zone = time_zone self.target_time_type = target_time_type @@ -1003,6 +1006,9 @@ class DataHandlerSingleGridColumn(DataHandlerSingleStation): meta_col_names = ['station_name', 'station_lon', 'station_lat', 'station_alt'] with self.loader as loader: + self.coords = loader.get_coordinates() + self.nearest_coords = loader.get_nearest_coords() + self.nearest_icoords = loader.get_nearest_icoords() if self._logical_z_coord_name is None: self._logical_z_coord_name = loader.logical_z_coord_name # # select defined variables at grid box or grid coloumn based on nearest icoords diff --git a/mlair/plotting/data_insight_plotting.py b/mlair/plotting/data_insight_plotting.py index db2b3340..31056c52 100644 --- a/mlair/plotting/data_insight_plotting.py +++ b/mlair/plotting/data_insight_plotting.py @@ -17,9 +17,10 @@ import matplotlib # matplotlib.use("Agg") from matplotlib import lines as mlines, pyplot as plt, patches as mpatches, dates as mdates from astropy.timeseries import LombScargle +from collections import Counter 
from mlair.data_handler import DataCollection -from mlair.helpers import TimeTrackingWrapper, to_list, remove_items +from mlair.helpers import TimeTrackingWrapper, to_list, remove_items, tables from mlair.plotting.abstract_plot_class import AbstractPlotClass
@@ -1271,3 +1272,94 @@ class PlotFirFilter(AbstractPlotClass):  # pragma: no cover
         file = os.path.join(self.plot_folder, "plot_data.pickle")
         with open(file, "wb") as f:
             dill.dump(data, f)
+
+
+class PlotStationsPerGridBox(AbstractPlotClass):
+    """
+    Count how many elements of a data set share the same grid box and visualise the result.
+
+    Each element of ``data_set`` is assigned to a grid box id that is built from its ``nearest_icoords``. The number
+    of elements per grid box is stored as csv file and its summary statistics as tex and md tables (all inside
+    ``report_path``). Two figures are created: the count for each grid box (``bar``) and the number of grid boxes
+    for each count (``hist``).
+
+    :param plot_folder: path to save the plots
+    :param data_set: iterable of data handlers, each providing ``_collection[0].nearest_icoords``
+    :param report_path: path to save the csv, tex and md report
+    """
+
+    def __init__(self, plot_folder: str, data_set, report_path):
+        """Count the stations per grid box, store the report and create the bar and the hist plot."""
+        file_name = "Stations_per_grid_box"
+        # Hand over the bare name: `_plot` appends a suffix to it, so it must not end with a file extension.
+        # NOTE(review): assumes that the base class adds the extension when saving - confirm.
+        super().__init__(plot_folder, file_name)
+        self.data_set = data_set
+        self.report_path = report_path
+        self.plot_data = self._count_boxes()
+        self.store_report(file_name)
+        self.base_plotname = self.plot_name
+
+        for plot_type in ("bar", "hist"):
+            self._plot(plot_type=plot_type)
+            self._save(bbox_inches="tight")
+
+    def _count_boxes(self):
+        """
+        Count the elements of the data set per grid box.
+
+        The grid box id joins the first entry of every ``nearest_icoords`` value by an underscore.
+
+        :return: data frame with one row per grid box id and a single column with the count, sorted by count and
+            grid box id
+        """
+        # NOTE(review): relies on the private `_collection` attribute and only looks at its first element.
+        box_ids = ["_".join(str(v[0]) for v in data._collection[0].nearest_icoords.values())
+                   for data in self.data_set]
+        df = pd.DataFrame(dict(Counter(box_ids)), index=["Number of Stations"]).T
+        df.index.name = "Grid box (id)"
+        return df.sort_values(by=[df.columns[0], df.index.name])
+
+    def _plot(self, plot_type):
+        """
+        Draw the requested figure and set the plot name accordingly.
+
+        :param plot_type: either ``"bar"`` or ``"hist"``
+        :raises ValueError: if ``plot_type`` is neither ``"bar"`` nor ``"hist"``
+        """
+        plot_methods = {"bar": self._plot_bar, "hist": self._plot_hist}
+        if plot_type not in plot_methods:
+            raise ValueError(f"`plot_type` must be 'bar' or 'hist', but is '{plot_type}'.")
+        plot_methods[plot_type]()
+        self.plot_name = f"{self.base_plotname}_{plot_type}"
+
+    def _plot_bar(self):
+        """Plot the count of each grid box as bar plot."""
+        _, ax = plt.subplots()
+        df = self.plot_data
+        df.plot(kind="bar", ax=ax, legend=False)
+        ax.set_xlabel(df.index.name)
+        ax.set_ylabel("Number of Pseudo-Stations in Grid box")
+
+    def _plot_hist(self):
+        """Plot the number of grid boxes that share the same count."""
+        _, ax = plt.subplots()
+        # labels are the distinct counts, counts tell how many grid boxes have this count
+        labels, counts = np.unique(self.plot_data, return_counts=True)
+        ax.bar(labels, counts, align="center")
+        ax.set_xticks(labels)  # use the axes at hand instead of the implicit current axes
+        ax.set_xlabel("Number of Pseudo-Stations in Grid box")
+        ax.set_ylabel("Number of Grid boxes")
+
+    def store_report(self, file_name):
+        """
+        Store the counts as csv file and their summary statistics as tex and md file inside the report path.
+
+        :param file_name: name of the report files without file extension
+        """
+        os.makedirs(self.report_path, exist_ok=True)  # do not fail if the report folder is still missing
+        self.plot_data.to_csv(os.path.join(self.report_path, f"{file_name}.csv"), sep=";")
+        df_descr = self.plot_data.describe().T
+        column_format = tables.create_column_format_for_tex(df_descr)
+        tables.save_to_tex(self.report_path, f"{file_name}.tex", column_format=column_format, df=df_descr)
+        tables.save_to_md(self.report_path, f"{file_name}.md", df=df_descr)
diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py index c38d92f6..9ddc39e9 100644 --- a/mlair/run_modules/post_processing.py +++ b/mlair/run_modules/post_processing.py @@ -26,7 +26,7 @@ from mlair.plotting.postprocessing_plotting import PlotMonthlySummary, PlotClima PlotCompetitiveSkillScore, PlotTimeSeries, PlotFeatureImportanceSkillScore, PlotConditionalQuantiles, \ PlotSeparationOfScales, PlotSampleUncertaintyFromBootstrap, PlotSectorialSkillScore, PlotScoresOnMap from mlair.plotting.data_insight_plotting import PlotStationMap, PlotAvailability, PlotAvailabilityHistogram, \ - PlotPeriodogram, PlotDataHistogram + PlotPeriodogram, PlotDataHistogram, PlotStationsPerGridBox from mlair.run_modules.run_environment import RunEnvironment @@ -535,6 +535,16 @@ class PostProcessing(RunEnvironment): iter_dim = self.data_store.get("iter_dim") separate_vars = self.data_store.get_default("separate_vars", None) + + try: + if "PlotStationsPerGridBox" in plot_list: + report_path = os.path.join(self.data_store.get("experiment_path"), "latex_report") + PlotStationsPerGridBox(plot_folder=self.plot_path, data_set=self.train_data, report_path=report_path) + except Exception as e: + logging.error(f"Could not create plot 
PlotStationsPerGridBox due to the following error:" + f"\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}") + + try: if ("filter" in self.test_data[0].get_X(as_numpy=False)[0].coords) and ( "PlotSeparationOfScales" in plot_list): -- GitLab