From 37c6a279b5386038475d581a95eb62b17283ae39 Mon Sep 17 00:00:00 2001
From: leufen1 <l.leufen@fz-juelich.de>
Date: Wed, 14 Apr 2021 15:35:09 +0200
Subject: [PATCH] PlotPeriodogram can now plot the periodogram before and after
 filtering

---
 .../data_handler_mixed_sampling.py            |   6 +-
 mlair/plotting/preprocessing_plotting.py      | 275 +++++++++++++-----
 mlair/run_modules/pre_processing.py           |   3 +-
 3 files changed, 203 insertions(+), 81 deletions(-)

diff --git a/mlair/data_handler/data_handler_mixed_sampling.py b/mlair/data_handler/data_handler_mixed_sampling.py
index 75e9e645..e2516257 100644
--- a/mlair/data_handler/data_handler_mixed_sampling.py
+++ b/mlair/data_handler/data_handler_mixed_sampling.py
@@ -56,7 +56,7 @@ class DataHandlerMixedSamplingSingleStation(DataHandlerSingleStation):
         kwargs.update({parameter_name: parameter})
 
     def make_input_target(self):
-        self._data = list(map(self.load_and_interpolate, [0, 1]))  # load input (0) and target (1) data
+        self._data = tuple(map(self.load_and_interpolate, [0, 1]))  # load input (0) and target (1) data
         self.set_inputs_and_targets()
 
     def load_and_interpolate(self, ind) -> [xr.DataArray, pd.DataFrame]:
@@ -110,7 +110,7 @@ class DataHandlerMixedSamplingWithFilterSingleStation(DataHandlerMixedSamplingSi
         A KZ filter is applied to the input data, which has hourly resolution. Labels Y are provided as aggregated values
         with daily resolution.
         """
-        self._data = list(map(self.load_and_interpolate, [0, 1]))  # load input (0) and target (1) data
+        self._data = tuple(map(self.load_and_interpolate, [0, 1]))  # load input (0) and target (1) data
         self.set_inputs_and_targets()
         self.apply_kz_filter()
 
@@ -158,7 +158,7 @@ class DataHandlerMixedSamplingWithFilterSingleStation(DataHandlerMixedSamplingSi
     def _extract_lazy(self, lazy_data):
         _data, self.meta, _input_data, _target_data, self.cutoff_period, self.cutoff_period_days = lazy_data
         start_inp, end_inp = self.update_start_end(0)
-        self._data = list(map(lambda x: self._slice_prep(_data[x], *self.update_start_end(x)), [0, 1]))
+        self._data = tuple(map(lambda x: self._slice_prep(_data[x], *self.update_start_end(x)), [0, 1]))
         self.input_data = self._slice_prep(_input_data, start_inp, end_inp)
         self.target_data = self._slice_prep(_target_data, self.start, self.end)
 
diff --git a/mlair/plotting/preprocessing_plotting.py b/mlair/plotting/preprocessing_plotting.py
index 53b29568..84df9b4b 100644
--- a/mlair/plotting/preprocessing_plotting.py
+++ b/mlair/plotting/preprocessing_plotting.py
@@ -4,6 +4,9 @@ __date__ = '2021-04-13'
 
 from typing import List, Dict
 import os
+import logging
+import multiprocessing
+import psutil
 
 import numpy as np
 import pandas as pd
@@ -18,7 +21,7 @@ from mlair.plotting.abstract_plot_class import AbstractPlotClass
 
 
 @TimeTrackingWrapper
-class PlotStationMap(AbstractPlotClass):
+class PlotStationMap(AbstractPlotClass):  # pragma: no cover
     """
     Plot geographical overview of all used stations as squares.
 
@@ -144,7 +147,7 @@ class PlotStationMap(AbstractPlotClass):
 
 
 @TimeTrackingWrapper
-class PlotAvailability(AbstractPlotClass):
+class PlotAvailability(AbstractPlotClass):  # pragma: no cover
     """
     Create data availability plot similar to a Gantt plot.
 
@@ -271,7 +274,7 @@ class PlotAvailability(AbstractPlotClass):
 
 
 @TimeTrackingWrapper
-class PlotAvailabilityHistogram(AbstractPlotClass):
+class PlotAvailabilityHistogram(AbstractPlotClass):  # pragma: no cover
     """
     Create data availability plots as histogram.
 
@@ -441,7 +444,7 @@ class PlotAvailabilityHistogram(AbstractPlotClass):
         plt.tight_layout()
 
 
-class PlotPeriodogram(AbstractPlotClass):
+class PlotPeriodogram(AbstractPlotClass):  # pragma: no cover
     """
     Create Lomb-Scargle periodogram of raw input and target data. The Lomb-Scargle version can deal with missing values.
 
@@ -457,89 +460,171 @@ class PlotPeriodogram(AbstractPlotClass):
     """
 
     def __init__(self, generator: Dict[str, DataCollection], plot_folder: str = ".", plot_name="periodogram",
-                 variables_dim="variables", time_dim="datetime", sampling="daily"):
+                 variables_dim="variables", time_dim="datetime", sampling="daily", use_multiprocessing=False):
         super().__init__(plot_folder, plot_name)
         self.variables_dim = variables_dim
         self.time_dim = time_dim
 
-        for pos, s in enumerate(sampling if isinstance(sampling, tuple) else (sampling,)):
+        for pos, s in enumerate(sampling if isinstance(sampling, tuple) else (sampling, sampling)):
             self._sampling = s
-            self._prepare_pgram(generator, pos)
+            self._add_text = {0: "input", 1: "target"}[pos]
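+            # multiple equals the number of filter components if the data handler applies a filter, otherwise 1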
+            multiple = self._has_filter_dimension(generator[0], pos)
+            self._prepare_pgram(generator, pos, multiple, use_multiprocessing=use_multiprocessing)
             self._plot(raw=True)
             self._plot(raw=False)
             self._plot_total(raw=True)
             self._plot_total(raw=False)
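+            # additionally compare unfiltered and filtered periodograms if the data handler provides a filter dimension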
+            if multiple > 1:
+                self._plot_difference()
 
-    def _prepare_pgram(self, generator, pos):
-        raw_data = dict()
-        plot_data = dict()
-        plot_data_raw = dict()
-        plot_data_mean = dict()
+    @staticmethod
+    def _has_filter_dimension(g, pos):
+        # check if the coords of the raw data differ from those of the input / target data
+        check_data = g.id_class
+        if "filter" not in [check_data.input_data, check_data.target_data][pos].coords.dims:
+            return 1
+        else:
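+            # only treat the filter dimension as relevant if the raw data and the filtered input data differ in their dims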
+            if len(set(check_data._data[0].coords.dims).symmetric_difference(check_data.input_data.coords.dims)) > 0:
+                return g.id_class.input_data.coords["filter"].shape[0]
+            else:
+                return 1
+
+    @TimeTrackingWrapper
+    def _prepare_pgram(self, generator, pos, multiple=1, use_multiprocessing=False):
+        """
+        Create periodogram data.
+        """
+        self.raw_data = []
+        self.plot_data = []
+        self.plot_data_raw = []
+        self.plot_data_mean = []
+        n_iter = multiple if multiple == 1 else multiple + 1
+        for m in range(n_iter):
+            plot_data_single = dict()
+            plot_data_raw_single = dict()
+            plot_data_mean_single = dict()
+            raw_data_single = self._prepare_pgram_parallel_gen(generator, m, pos, use_multiprocessing)
+            # raw_data_single = self._prepare_pgram_parallel_var(generator, m, pos, use_multiprocessing)
+            self.f_index = np.logspace(-3, 0 if self._sampling == "daily" else np.log10(24), 1000)
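+            # common frequency axis in 1 / day: up to 1 for daily data, up to 24 for hourly data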
+            for var in raw_data_single.keys():
+                pgram_com = []
+                pgram_mean = 0
+                all_data = raw_data_single[var]
+                pgram_mean_raw = np.zeros((len(self.f_index), len(all_data)))
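+                # interpolate every station's periodogram onto the common frequency axis and average over all stations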
+                for i, (f, pgram) in enumerate(all_data):
+                    d = np.interp(self.f_index, f, pgram)
+                    pgram_com.append(d)
+                    pgram_mean += d
+                    pgram_mean_raw[:, i] = d
+                pgram_mean /= len(all_data)
+                plot_data_single[var] = pgram_com
+                plot_data_mean_single[var] = (self.f_index, pgram_mean)
+                plot_data_raw_single[var] = (self.f_index, pgram_mean_raw)
+            self.plot_data.append(plot_data_single)
+            self.plot_data_mean.append(plot_data_mean_single)
+            self.plot_data_raw.append(plot_data_raw_single)
+
+    def _prepare_pgram_parallel_var(self, generator, m, pos, use_multiprocessing):
+        """Implementation of data preprocessing using parallel variables element processing."""
+        raw_data_single = dict()
         for g in generator:
-            print(g)
-            d = g.id_class._data
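+            # same data selection as in f_proc_2: raw data for m == 0, the (m - 1)-th filter component otherwise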
+            if m == 0:
+                d = g.id_class._data
+            else:
+                gd = g.id_class
+                filter_sel = {"filter": gd.input_data.coords["filter"][m - 1]}
+                d = (gd.input_data.sel(filter_sel), gd.target_data)
             d = d[pos] if isinstance(d, tuple) else d
-            for var in d[self.variables_dim].values:
-                var_str = str(var)
-                d_var = d.loc[{self.variables_dim: var}].squeeze().dropna(self.time_dim)
-                t = (d_var.datetime - d_var.datetime[0]).astype("timedelta64[h]").values / np.timedelta64(1, "D")
-                f, pgram = LombScargle(t, d_var.values.flatten(), nterms=1).autopower()
-                raw_data[var_str] = [(f, pgram)] if var_str not in raw_data.keys() else raw_data[var_str] + [(f, pgram)]
-        self.f_index = np.logspace(-3, 0 if self._sampling == "daily" else np.log10(24), 1000)
-        for var in raw_data.keys():
-            pgram_com = []
-            pgram_mean = 0
-            all_data = raw_data[var]
-            pgram_mean_raw = np.zeros((len(self.f_index), len(all_data)))
-            for i, (f, pgram) in enumerate(all_data):
-                d = np.interp(self.f_index, f, pgram)
-                pgram_com.append(d)
-                pgram_mean += d
-                pgram_mean_raw[:, i] = d
-            pgram_mean /= len(all_data)
-            plot_data[var] = pgram_com
-            plot_data_mean[var] = (self.f_index, pgram_mean)
-            plot_data_raw[var] = (self.f_index, pgram_mean_raw)
-        self.plot_data = plot_data
-        self.plot_data_mean = plot_data_mean
-        self.plot_data_raw = plot_data_raw
+            res = []
+            if multiprocessing.cpu_count() > 1 and use_multiprocessing:  # parallel solution
+                pool = multiprocessing.Pool(
+                    min([psutil.cpu_count(logical=False), len(d[self.variables_dim].values),
+                         16]))  # use only physical cpus
+                output = [
+                    pool.apply_async(f_proc,
+                                     args=(var, d.loc[{self.variables_dim: var}].squeeze().dropna(self.time_dim)))
+                    for var in d[self.variables_dim].values]
+                for p in output:
+                    res.append(p.get())
+                pool.close()
+            else:  # serial solution
+                for var in d[self.variables_dim].values:
+                    res.append(f_proc(var, d.loc[{self.variables_dim: var}].squeeze().dropna(self.time_dim)))
+            for (var_str, f, pgram) in res:
+                if var_str not in raw_data_single.keys():
+                    raw_data_single[var_str] = [(f, pgram)]
+                else:
+                    raw_data_single[var_str] = raw_data_single[var_str] + [(f, pgram)]
+        return raw_data_single
+
+    def _prepare_pgram_parallel_gen(self, generator, m, pos, use_multiprocessing):
+        """Implementation of data preprocessing using parallel generator element processing."""
+        raw_data_single = dict()
+        res = []
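+        # one Lomb-Scargle task per generator element (station); the per-variable results are merged below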
+        if multiprocessing.cpu_count() > 1 and use_multiprocessing:  # parallel solution
+            pool = multiprocessing.Pool(
+                min([psutil.cpu_count(logical=False), len(generator), 16]))  # use only physical cpus
+            output = [
+                pool.apply_async(f_proc_2, args=(g, m, pos, self.variables_dim, self.time_dim))
+                for g in generator]
+            for p in output:
+                res.append(p.get())
+            pool.close()
+        else:
+            for g in generator:
+                res.append(f_proc_2(g, m, pos, self.variables_dim, self.time_dim))
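+        # merge the per-station dictionaries into a single dict of {variable: [(f, pgram), ...]}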
+        for res_dict in res:
+            for k, v in res_dict.items():
+                if k not in raw_data_single.keys():
+                    raw_data_single[k] = v
+                else:
+                    raw_data_single[k] = raw_data_single[k] + v
+        return raw_data_single
 
     @staticmethod
-    def _add_annotation_line(pos, div, lims, unit):
+    def _add_annotation_line(ax, pos, div, lims, unit):
         for p in to_list(pos):  # per year
-            plt.vlines(p / div, *lims, "black")
-            plt.text(p / div, lims[0], r"%s$%s^{-1}$" % (p, unit), rotation="vertical", rotation_mode="anchor")
+            ax.vlines(p / div, *lims, "black")
+            ax.text(p / div, lims[0], r"%s$%s^{-1}$" % (p, unit), rotation="vertical", rotation_mode="anchor")
+
+    def _format_figure(self, ax, var_name="total"):
+        """
+        Set log scale on both axes, add labels and annotation lines, and set the title.
+
+        :param ax: current ax object
+        :param var_name: name of the variable included in the title
+        """
+        ax.set_yscale('log')
+        ax.set_xscale('log')
+        ax.set_ylabel("power", fontsize='x-large')
+        ax.set_xlabel("frequency $[day^{-1}]$", fontsize='x-large')
+        lims = ax.get_ylim()
+        self._add_annotation_line(ax, [1, 2, 3], 365.25, lims, "yr")  # per year
+        self._add_annotation_line(ax, 1, 365.25 / 12, lims, "m")  # per month
+        self._add_annotation_line(ax, 1, 7, lims, "w")  # per week
+        self._add_annotation_line(ax, [1, 0.5], 1, lims, "d")  # per day
+        if self._sampling == "hourly":
+            self._add_annotation_line(ax, 2, 1, lims, "d")  # per day
+            self._add_annotation_line(ax, [1, 0.5], 1 / 24., lims, "h")  # per hour
+        title = f"Periodogram ({var_name})"
+        ax.set_title(title)
 
     def _plot(self, raw=True):
         plot_path = os.path.join(os.path.abspath(self.plot_folder),
-                                 f"{self.plot_name}{'_raw' if raw else ''}_{self._sampling}.pdf")
+                                 f"{self.plot_name}{'_raw' if raw else ''}_{self._sampling}_{self._add_text}.pdf")
         pdf_pages = matplotlib.backends.backend_pdf.PdfPages(plot_path)
-        for var in self.plot_data.keys():
+        plot_data = self.plot_data[0]
+        plot_data_mean = self.plot_data_mean[0]
+        for var in plot_data.keys():
             fig, ax = plt.subplots()
             if raw is True:
-                for pgram in self.plot_data[var]:
+                for pgram in plot_data[var]:
                     ax.plot(self.f_index, pgram, "lightblue")
-                ax.plot(*self.plot_data_mean[var], "blue")
+                ax.plot(*plot_data_mean[var], "blue")
             else:
-                ma = pd.DataFrame(np.vstack(self.plot_data[var]).T).rolling(5, center=True, axis=0)
+                ma = pd.DataFrame(np.vstack(plot_data[var]).T).rolling(5, center=True, axis=0)
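+                # 5-point rolling window along the frequency axis; the mean over stations gives the line, min / max the envelope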
                 mean = ma.mean().mean(axis=1).values.flatten()
                 upper, lower = ma.max().mean(axis=1).values.flatten(), ma.min().mean(axis=1).values.flatten()
                 ax.plot(self.f_index, mean, "blue")
                 ax.fill_between(self.f_index, lower, upper, color="lightblue")
-            plt.yscale("log")
-            plt.xscale("log")
-            ax.set_ylabel("power", fontsize='x-large')
-            ax.set_xlabel("frequency $[day^{-1}$]", fontsize='x-large')
-            lims = ax.get_ylim()
-            self._add_annotation_line([1, 2, 3], 365.25, lims, "yr")  # per year
-            self._add_annotation_line(1, 365.25 / 12, lims, "m")  # per month
-            self._add_annotation_line(1, 7, lims, "w")  # per week
-            self._add_annotation_line([1, 0.5], 1, lims, "d")  # per day
-            if self._sampling == "hourly":
-                self._add_annotation_line(2, 1, lims, "d")  # per day
-                self._add_annotation_line([1, 0.5], 1 / 24., lims, "h")  # per hour
-            title = f"Periodogram ({var})"
-            plt.title(title)
+            self._format_figure(ax, var)
             pdf_pages.savefig()
         # close all open figures / plots
         pdf_pages.close()
@@ -547,12 +632,13 @@ class PlotPeriodogram(AbstractPlotClass):
 
     def _plot_total(self, raw=True):
         plot_path = os.path.join(os.path.abspath(self.plot_folder),
-                                 f"{self.plot_name}{'_raw' if raw else ''}_{self._sampling}_total.pdf")
+                                 f"{self.plot_name}{'_raw' if raw else ''}_{self._sampling}_{self._add_text}_total.pdf")
         pdf_pages = matplotlib.backends.backend_pdf.PdfPages(plot_path)
+        plot_data_raw = self.plot_data_raw[0]
         fig, ax = plt.subplots()
         res = None
-        for var in self.plot_data_raw.keys():
-            d_var = self.plot_data_raw[var][1]
+        for var in plot_data_raw.keys():
+            d_var = plot_data_raw[var][1]
             res = d_var if res is None else np.concatenate((res, d_var), axis=-1)
         if raw is True:
             for i in range(res.shape[1]):
@@ -564,21 +650,56 @@ class PlotPeriodogram(AbstractPlotClass):
             upper, lower = ma.max().mean(axis=1).values.flatten(), ma.min().mean(axis=1).values.flatten()
             ax.plot(self.f_index, mean, "blue")
             ax.fill_between(self.f_index, lower, upper, color="lightblue")
-        plt.yscale("log")
-        plt.xscale("log")
-        ax.set_ylabel("power", fontsize='x-large')
-        ax.set_xlabel("frequency $[day^{-1}$]", fontsize='x-large')
-        lims = ax.get_ylim()
-        self._add_annotation_line([1, 2, 3], 365.25, lims, "yr")  # per year
-        self._add_annotation_line(1, 365.25 / 12, lims, "m")  # per month
-        self._add_annotation_line(1, 7, lims, "w")  # per week
-        self._add_annotation_line([1, 0.5], 1, lims, "d")  # per day
-        if self._sampling == "hourly":
-            self._add_annotation_line(2, 1, lims, "d")  # per day
-            self._add_annotation_line([1, 0.5], 1 / 24., lims, "h")  # per hour
-        title = f"Periodogram (total)"
-        plt.title(title)
+        self._format_figure(ax, "total")
         pdf_pages.savefig()
         # close all open figures / plots
         pdf_pages.close()
         plt.close('all')
+
+    def _plot_difference(self):
+        plot_name = f"{self.plot_name}_{self._sampling}_{self._add_text}_filter.pdf"
+        plot_path = os.path.join(os.path.abspath(self.plot_folder), plot_name)
+        logging.info(f"... plotting {plot_name}")
+        pdf_pages = matplotlib.backends.backend_pdf.PdfPages(plot_path)
+        colors = ["blue", "red", "green", "orange"]
+        max_iter = len(self.plot_data)
+        var_keys = self.plot_data[0].keys()
+        for var in var_keys:
+            fig, ax = plt.subplots()
+            for i in reversed(range(max_iter)):
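+                # reversed order plots the unfiltered curve (i == 0) last so it stays on top; only this curve gets a min / max envelope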
+                plot_data = self.plot_data[i]
+                c = colors[i]
+                ma = pd.DataFrame(np.vstack(plot_data[var]).T).rolling(5, center=True, axis=0)
+                mean = ma.mean().mean(axis=1).values.flatten()
+                ax.plot(self.f_index, mean, c)
+                if i < 1:
+                    upper, lower = ma.max().mean(axis=1).values.flatten(), ma.min().mean(axis=1).values.flatten()
+                    ax.fill_between(self.f_index, lower, upper, color="light" + c, alpha=0.5)
+            self._format_figure(ax, var)
+            pdf_pages.savefig()
+        # close all open figures / plots
+        pdf_pages.close()
+        plt.close('all')
+
+
+def f_proc(var, d_var):
+    var_str = str(var)
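+    # time axis in days since the first sample; the Lomb-Scargle periodogram copes with the irregular spacing left by dropped NaNs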
+    t = (d_var.datetime - d_var.datetime[0]).astype("timedelta64[h]").values / np.timedelta64(1, "D")
+    f, pgram = LombScargle(t, d_var.values.flatten(), nterms=1).autopower()
+    return var_str, f, pgram
+
+
+def f_proc_2(g, m, pos, variables_dim, time_dim):
+    raw_data_single = dict()
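+    # m == 0 uses the raw (unfiltered) data, m > 0 selects the (m - 1)-th filter component of the filtered input data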
+    if m == 0:
+        d = g.id_class._data
+    else:
+        gd = g.id_class
+        filter_sel = {"filter": gd.input_data.coords["filter"][m - 1]}
+        d = (gd.input_data.sel(filter_sel), gd.target_data)
+    d = d[pos] if isinstance(d, tuple) else d
+    for var in d[variables_dim].values:
+        d_var = d.loc[{variables_dim: var}].squeeze().dropna(time_dim)
+        var_str, f, pgram = f_proc(var, d_var)
+        raw_data_single[var_str] = [(f, pgram)]
+    return raw_data_single
diff --git a/mlair/run_modules/pre_processing.py b/mlair/run_modules/pre_processing.py
index d1ec0c60..148c34a0 100644
--- a/mlair/run_modules/pre_processing.py
+++ b/mlair/run_modules/pre_processing.py
@@ -346,6 +346,7 @@ class PreProcessing(RunEnvironment):
         plot_path: str = self.data_store.get("plot_path")
 
         sampling = self.data_store.get("sampling")
+        use_multiprocessing = self.data_store.get("use_multiprocessing")
 
         try:
             if "PlotStationMap" in plot_list:
@@ -383,7 +384,7 @@ class PreProcessing(RunEnvironment):
         try:
             if "PlotPeriodogram" in plot_list:
                 PlotPeriodogram(train_data, plot_folder=plot_path, time_dim=time_dim, variables_dim=target_dim,
-                                sampling=sampling)
+                                sampling=sampling, use_multiprocessing=use_multiprocessing)
 
         except Exception as e:
             logging.error(f"Could not create plot PlotPeriodogram due to the following error: {e}")
-- 
GitLab