diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py index ff74da37ed9706d9674bf773249c84e4e1e7424b..a633dec1d65244aacb531d7dc1f94c8e2931d29c 100644 --- a/mlair/run_modules/post_processing.py +++ b/mlair/run_modules/post_processing.py @@ -21,7 +21,6 @@ from mlair.model_modules.linear_model import OrdinaryLeastSquaredModel from mlair.model_modules import AbstractModelClass from mlair.plotting.postprocessing_plotting import PlotMonthlySummary, PlotClimatologicalSkillScore, \ PlotCompetitiveSkillScore, PlotTimeSeries, PlotBootstrapSkillScore, PlotConditionalQuantiles, PlotSeparationOfScales -from mlair.plotting.preprocessing_plotting import PlotStationMap, PlotAvailability, PlotAvailabilityHistogram from mlair.run_modules.run_environment import RunEnvironment @@ -325,23 +324,6 @@ class PostProcessing(RunEnvironment): except Exception as e: logging.error(f"Could not create plot PlotConditionalQuantiles due to the following error: {e}") - try: - if "PlotStationMap" in plot_list: - if self.data_store.get("hostname")[:2] in self.data_store.get("hpc_hosts") or self.data_store.get( - "hostname")[:6] in self.data_store.get("hpc_hosts"): - logging.warning( - f"Skip 'PlotStationMap` because running on a hpc node: {self.data_store.get('hostname')}") - else: - gens = [(self.train_data, {"marker": 5, "ms": 9}), - (self.val_data, {"marker": 6, "ms": 9}), - (self.test_data, {"marker": 4, "ms": 9})] - PlotStationMap(generators=gens, plot_folder=self.plot_path) - gens = [(self.train_val_data, {"marker": 8, "ms": 9}), - (self.test_data, {"marker": 9, "ms": 9})] - PlotStationMap(generators=gens, plot_folder=self.plot_path, plot_name="station_map_var") - except Exception as e: - logging.error(f"Could not create plot PlotStationMap due to the following error: {e}") - try: if "PlotMonthlySummary" in plot_list: PlotMonthlySummary(self.test_data.keys(), path, r"forecasts_%s_test.nc", self.target_var, @@ -372,22 +354,6 @@ class PostProcessing(RunEnvironment): except Exception as e: logging.error(f"Could not create plot PlotTimeSeries due to the following error: {e}") - try: - if "PlotAvailability" in plot_list: - avail_data = {"train": self.train_data, "val": self.val_data, "test": self.test_data} - PlotAvailability(avail_data, plot_folder=self.plot_path, time_dimension=time_dim, - window_dimension=window_dim) - except Exception as e: - logging.error(f"Could not create plot PlotAvailability due to the following error: {e}") - - try: - if "PlotAvailabilityHistogram" in plot_list: - avail_data = {"train": self.train_data, "val": self.val_data, "test": self.test_data} - PlotAvailabilityHistogram(avail_data, plot_folder=self.plot_path, station_dim=iter_dim, - history_dim=window_dim) - except Exception as e: - logging.error(f"Could not create plot PlotAvailabilityHistogram due to the following error: {e}") - def calculate_test_score(self): """Evaluate test score of model and save locally.""" diff --git a/mlair/run_modules/pre_processing.py b/mlair/run_modules/pre_processing.py index f59a4e89ced738c9198619ec0d117df2edf3ee93..3c2670aadf06c99b4d2491a9d5cc885dd605421b 100644 --- a/mlair/run_modules/pre_processing.py +++ b/mlair/run_modules/pre_processing.py @@ -18,6 +18,7 @@ from mlair.helpers import TimeTracking, to_list, tables from mlair.configuration import path_config from mlair.helpers.join import EmptyQueryResult from mlair.run_modules.run_environment import RunEnvironment +from mlair.plotting.preprocessing_plotting import PlotStationMap, PlotAvailability, PlotAvailabilityHistogram class PreProcessing(RunEnvironment): @@ -67,6 +68,7 @@ class PreProcessing(RunEnvironment): self.split_train_val_test() self.report_pre_processing() self.prepare_competitors() + self.plot() def report_pre_processing(self): """Log some metrics on data and create latex report.""" @@ -327,6 +329,54 @@ class PreProcessing(RunEnvironment): else: logging.info("No preparation required because no competitor was provided to the workflow.") + def plot(self): + logging.info("Run plotting routines...") + + plot_list = self.data_store.get("plot_list", "postprocessing") + time_dim = self.data_store.get("time_dim") + window_dim = self.data_store.get("window_dim") + target_dim = self.data_store.get("target_dim") + iter_dim = self.data_store.get("iter_dim") + + train_data = self.data_store.get("data_collection", "train") + val_data = self.data_store.get("data_collection", "val") + test_data = self.data_store.get("data_collection", "test") + train_val_data = self.data_store.get("data_collection", "train_val") + plot_path: str = self.data_store.get("plot_path") + + try: + if "PlotStationMap" in plot_list: + if self.data_store.get("hostname")[:2] in self.data_store.get("hpc_hosts") or self.data_store.get( + "hostname")[:6] in self.data_store.get("hpc_hosts"): + logging.warning( + f"Skip 'PlotStationMap` because running on a hpc node: {self.data_store.get('hostname')}") + else: + gens = [(train_data, {"marker": 5, "ms": 9}), + (val_data, {"marker": 6, "ms": 9}), + (test_data, {"marker": 4, "ms": 9})] + PlotStationMap(generators=gens, plot_folder=plot_path) + gens = [(train_val_data, {"marker": 8, "ms": 9}), + (test_data, {"marker": 9, "ms": 9})] + PlotStationMap(generators=gens, plot_folder=plot_path, plot_name="station_map_var") + except Exception as e: + logging.error(f"Could not create plot PlotStationMap due to the following error: {e}") + + try: + if "PlotAvailability" in plot_list: + avail_data = {"train": train_data, "val": val_data, "test": test_data} + PlotAvailability(avail_data, plot_folder=plot_path, time_dimension=time_dim, + window_dimension=window_dim) + except Exception as e: + logging.error(f"Could not create plot PlotAvailability due to the following error: {e}") + + try: + if "PlotAvailabilityHistogram" in plot_list: + avail_data = {"train": train_data, "val": val_data, "test": test_data} + PlotAvailabilityHistogram(avail_data, plot_folder=plot_path, station_dim=iter_dim, + history_dim=window_dim) + except Exception as e: + logging.error(f"Could not create plot PlotAvailabilityHistogram due to the following error: {e}") + def f_proc(data_handler, station, name_affix, store, **kwargs): """