diff --git a/mlair/data_handler/default_data_handler.py b/mlair/data_handler/default_data_handler.py index 87fc83b0c5d97631b9b0e01aa490be20c107ed1f..3a57d9febc714c81a68c21facab55957eabf32d9 100644 --- a/mlair/data_handler/default_data_handler.py +++ b/mlair/data_handler/default_data_handler.py @@ -299,6 +299,7 @@ class DefaultDataHandler(AbstractDataHandler): for p in output: dh, s = p.get() _inner() + pool.close() else: # serial solution logging.info("use serial transformation approach") for station in set_stations: diff --git a/mlair/plotting/data_insight_plotting.py b/mlair/plotting/data_insight_plotting.py index 84df9b4b5390cef04d989c70bfca448b1db7d7b5..79c26522b90fa28391c4b358c125c90a199a345a 100644 --- a/mlair/plotting/data_insight_plotting.py +++ b/mlair/plotting/data_insight_plotting.py @@ -468,26 +468,27 @@ class PlotPeriodogram(AbstractPlotClass): # pragma: no cover for pos, s in enumerate(sampling if isinstance(sampling, tuple) else (sampling, sampling)): self._sampling = s self._add_text = {0: "input", 1: "target"}[pos] - multiple = self._has_filter_dimension(generator[0], pos) + multiple, label_names = self._has_filter_dimension(generator[0], pos) self._prepare_pgram(generator, pos, multiple, use_multiprocessing=use_multiprocessing) self._plot(raw=True) self._plot(raw=False) self._plot_total(raw=True) self._plot_total(raw=False) if multiple > 1: - self._plot_difference() + self._plot_difference(label_names) @staticmethod def _has_filter_dimension(g, pos): # check if coords raw data differs from input / target data check_data = g.id_class if "filter" not in [check_data.input_data, check_data.target_data][pos].coords.dims: - return 1 + return 1, [] else: if len(set(check_data._data[0].coords.dims).symmetric_difference(check_data.input_data.coords.dims)) > 0: - return g.id_class.input_data.coords["filter"].shape[0] + return g.id_class.input_data.coords["filter"].shape[0], g.id_class.input_data.coords[ + "filter"].values.tolist() else: - return 1 + return 1, [] @TimeTrackingWrapper def _prepare_pgram(self, generator, pos, multiple=1, use_multiprocessing=False): @@ -546,6 +547,7 @@ class PlotPeriodogram(AbstractPlotClass): # pragma: no cover for var in d[self.variables_dim].values] for i, p in enumerate(output): res.append(p.get()) + pool.close() else: # serial solution for var in d[self.variables_dim].values: res.append(f_proc(var, d.loc[{self.variables_dim: var}].squeeze().dropna(self.time_dim))) @@ -568,6 +570,7 @@ class PlotPeriodogram(AbstractPlotClass): # pragma: no cover for g in generator] for i, p in enumerate(output): res.append(p.get()) + pool.close() else: for g in generator: res.append(f_proc_2(g, m, pos, self.variables_dim, self.time_dim)) @@ -656,12 +659,13 @@ class PlotPeriodogram(AbstractPlotClass): # pragma: no cover pdf_pages.close() plt.close('all') - def _plot_difference(self): + def _plot_difference(self, label_names): plot_name = f"{self.plot_name}_{self._sampling}_{self._add_text}_filter.pdf" plot_path = os.path.join(os.path.abspath(self.plot_folder), plot_name) logging.info(f"... plotting {plot_name}") pdf_pages = matplotlib.backends.backend_pdf.PdfPages(plot_path) - colors = ["blue", "red", "green", "orange"] + colors = ["blue", "red", "green", "orange", "purple", "black", "grey"] + label_names = ["orig"] + label_names max_iter = len(self.plot_data) var_keys = self.plot_data[0].keys() for var in var_keys: @@ -671,11 +675,12 @@ class PlotPeriodogram(AbstractPlotClass): # pragma: no cover c = colors[i] ma = pd.DataFrame(np.vstack(plot_data[var]).T).rolling(5, center=True, axis=0) mean = ma.mean().mean(axis=1).values.flatten() - ax.plot(self.f_index, mean, c) + ax.plot(self.f_index, mean, c, label=label_names[i]) if i < 1: upper, lower = ma.max().mean(axis=1).values.flatten(), ma.min().mean(axis=1).values.flatten() - ax.fill_between(self.f_index, lower, upper, color="light" + c, alpha=0.5) + ax.fill_between(self.f_index, lower, upper, color="light" + c, alpha=0.5, label=None) self._format_figure(ax, var) + ax.legend() pdf_pages.savefig() # close all open figures / plots pdf_pages.close() diff --git a/mlair/run_modules/pre_processing.py b/mlair/run_modules/pre_processing.py index 4edf8e965c7140be067428b4ee1c596b8a85b312..68164b6fa3c6b95727f634baebd40e988482abd5 100644 --- a/mlair/run_modules/pre_processing.py +++ b/mlair/run_modules/pre_processing.py @@ -256,6 +256,7 @@ class PreProcessing(RunEnvironment): if dh is not None: collection.add(dh) valid_stations.append(s) + pool.close() else: # serial solution logging.info("use serial validate station approach") for station in set_stations: