From ab5136ae25dedc7af8f97f33bd9720d1432605c2 Mon Sep 17 00:00:00 2001
From: leufen1 <l.leufen@fz-juelich.de>
Date: Tue, 13 Apr 2021 16:55:33 +0200
Subject: [PATCH] periodogram plot is included in preprocessing plotting for
 first time

---
 HPC_setup/requirements_HDFML_additionals.txt  |  1 +
 HPC_setup/requirements_JUWELS_additionals.txt |  1 +
 mlair/plotting/preprocessing_plotting.py      | 95 ++++++++++++++++++-
 mlair/run_modules/pre_processing.py           | 13 ++-
 requirements.txt                              |  1 +
 requirements_gpu.txt                          |  1 +
 6 files changed, 110 insertions(+), 2 deletions(-)

diff --git a/HPC_setup/requirements_HDFML_additionals.txt b/HPC_setup/requirements_HDFML_additionals.txt
index b2a29fbf..fd22a309 100644
--- a/HPC_setup/requirements_HDFML_additionals.txt
+++ b/HPC_setup/requirements_HDFML_additionals.txt
@@ -1,6 +1,7 @@
 absl-py==0.11.0
 appdirs==1.4.4
 astor==0.8.1
+astropy==4.1
 attrs==20.3.0
 bottleneck==1.3.2
 cached-property==1.5.2
diff --git a/HPC_setup/requirements_JUWELS_additionals.txt b/HPC_setup/requirements_JUWELS_additionals.txt
index b2a29fbf..fd22a309 100644
--- a/HPC_setup/requirements_JUWELS_additionals.txt
+++ b/HPC_setup/requirements_JUWELS_additionals.txt
@@ -1,6 +1,7 @@
 absl-py==0.11.0
 appdirs==1.4.4
 astor==0.8.1
+astropy==4.1
 attrs==20.3.0
 bottleneck==1.3.2
 cached-property==1.5.2
diff --git a/mlair/plotting/preprocessing_plotting.py b/mlair/plotting/preprocessing_plotting.py
index aa61b1f3..da5916fb 100644
--- a/mlair/plotting/preprocessing_plotting.py
+++ b/mlair/plotting/preprocessing_plotting.py
@@ -3,14 +3,17 @@ __author__ = "Lukas Leufen, Felix Kleinert"
 __date__ = '2021-04-13'
 
 from typing import List, Dict
+import os
 
 import numpy as np
 import pandas as pd
 import xarray as xr
+import matplotlib
 from matplotlib import lines as mlines, pyplot as plt, patches as mpatches, dates as mdates
+from astropy.timeseries import LombScargle
 
 from mlair.data_handler import DataCollection
-from mlair.helpers import TimeTrackingWrapper
+from mlair.helpers import TimeTrackingWrapper, to_list
 from mlair.plotting.abstract_plot_class import AbstractPlotClass
 
 
@@ -436,3 +439,138 @@ class PlotAvailabilityHistogram(AbstractPlotClass):
         plt.xlabel('Number of samples')
         plt.xlim((bins[0], bins[-1]))
         plt.tight_layout()
+
+
+class PlotPeriodogram(AbstractPlotClass):
+    """
+    Plot Lomb-Scargle periodograms of all variables found in the given data handlers.
+
+    For each sampling and each variable, a periodogram is computed per data handler, interpolated onto a common
+    logarithmic frequency axis and written as one page of a PDF file. Two files are created per sampling:
+    `<plot_name>_raw_<sampling>.pdf` shows every individual periodogram together with their mean, and
+    `<plot_name>_<sampling>.pdf` shows a smoothed mean with a smoothed min/max envelope.
+
+    :param generator: iterable of data handlers; each element must provide its data as `id_class._data`, either as a
+        single xarray object or as a tuple with one entry per sampling
+    :param plot_folder: path to save the plots (default: current directory)
+    :param plot_name: base name of the created files (default "periodogram")
+    :param variables_dim: name of the dimension that holds the variables (default "variables")
+    :param time_dim: name of the time dimension (default "datetime")
+    :param sampling: a single sampling name or a tuple / list of names; the position of a name selects the
+        corresponding entry of tuple-like data. Any name other than "daily" gets the extended frequency range up
+        to 24 per day, the additional sub-daily annotation lines are only drawn for "hourly".
+    """
+
+    def __init__(self, generator: DataCollection, plot_folder: str = ".", plot_name="periodogram",
+                 variables_dim="variables", time_dim="datetime", sampling="daily"):
+        super().__init__(plot_folder, plot_name)
+        self.variables_dim = variables_dim
+        self.time_dim = time_dim
+
+        samplings = tuple(sampling) if isinstance(sampling, (tuple, list)) else (sampling,)
+        for pos, s in enumerate(samplings):
+            self._sampling = s
+            self._prepare_pgram(generator, pos)
+            self._plot(raw=True)
+            self._plot(raw=False)
+
+    def _prepare_pgram(self, generator, pos):
+        """
+        Compute the periodograms of all variables and store them on a common frequency axis.
+
+        Results are stored in `self.f_index` (common frequencies in 1/day), `self.plot_data` (list of interpolated
+        periodograms per variable), `self.plot_data_mean` (frequencies and mean periodogram per variable) and
+        `self.plot_data_raw` (frequencies and all periodograms as 2D array per variable).
+
+        :param generator: iterable of data handlers providing `id_class._data`
+        :param pos: index of the current sampling, used to select the entry of tuple-like data
+        """
+        raw_data = dict()
+        for g in generator:
+            d = g.id_class._data
+            d = d[pos] if isinstance(d, tuple) else d
+            for var in d[self.variables_dim].values:
+                d_var = d.loc[{self.variables_dim: var}].squeeze().dropna(self.time_dim)
+                time = d_var[self.time_dim]
+                if time.size < 2:  # a periodogram needs a non-zero time span, skip (nearly) empty series
+                    continue
+                # time since the first sample, converted to hour resolution and expressed in days
+                t = (time - time[0]).astype("timedelta64[h]").values / np.timedelta64(1, "D")
+                f, pgram = LombScargle(t, d_var.values.flatten(), nterms=1).autopower()
+                raw_data.setdefault(str(var), []).append((f, pgram))
+        # common log-spaced frequency axis in 1/day: from 1e-3 up to 1 (daily) or 24 (any other sampling)
+        self.f_index = np.logspace(-3, 0 if self._sampling == "daily" else np.log10(24), 1000)
+        plot_data, plot_data_mean, plot_data_raw = dict(), dict(), dict()
+        for var, all_data in raw_data.items():
+            pgram_com = [np.interp(self.f_index, f, pgram) for f, pgram in all_data]
+            pgram_all = np.column_stack(pgram_com)  # shape: (frequencies, data handlers)
+            plot_data[var] = pgram_com
+            plot_data_mean[var] = (self.f_index, pgram_all.mean(axis=1))
+            plot_data_raw[var] = (self.f_index, pgram_all)
+        self.plot_data = plot_data
+        self.plot_data_mean = plot_data_mean
+        self.plot_data_raw = plot_data_raw
+
+    @staticmethod
+    def _add_annotation_line(pos, div, lims, unit):
+        """
+        Mark frequencies on the current axes with a black vertical line and a vertical text label.
+
+        :param pos: single number or list of numbers, each is drawn at frequency `p / div`
+        :param div: divisor converting `pos` to the unit of the frequency axis
+        :param lims: lower and upper y limit of the lines, the label is placed at the lower limit
+        :param unit: unit string used in the label, rendered as `<p><unit>^-1`
+        """
+        for p in to_list(pos):
+            plt.vlines(p / div, *lims, "black")
+            plt.text(p / div, lims[0], r"%s$%s^{-1}$" % (p, unit), rotation="vertical", rotation_mode="anchor")
+
+    def _plot(self, raw=True):
+        """
+        Write one PDF file containing one periodogram page per variable.
+
+        :param raw: if True, plot each individual periodogram (light blue) and their mean (blue) and add `_raw` to
+            the file name; otherwise plot the smoothed mean together with the smoothed min/max envelope
+        """
+        # import the submodule explicitly, `import matplotlib` alone does not guarantee that it is loaded
+        from matplotlib.backends.backend_pdf import PdfPages
+
+        plot_path = os.path.join(os.path.abspath(self.plot_folder),
+                                 f"{self.plot_name}{'_raw' if raw else ''}_{self._sampling}.pdf")
+        try:
+            with PdfPages(plot_path) as pdf_pages:  # the file is closed even if a page fails
+                for var in self.plot_data.keys():
+                    fig = self._plot_variable(var, raw)
+                    pdf_pages.savefig(fig)
+                    plt.close(fig)  # do not accumulate one open figure per variable
+        finally:
+            plt.close('all')
+
+    def _plot_variable(self, var, raw):
+        """Create and return the periodogram figure of a single variable (see `_plot` for the meaning of `raw`)."""
+        fig, ax = plt.subplots()
+        if raw is True:
+            for pgram in self.plot_data[var]:
+                ax.plot(self.f_index, pgram, "lightblue")
+            ax.plot(*self.plot_data_mean[var], "blue")
+        else:
+            # 5-point centred rolling window along the frequency axis, then averaged over all data handlers
+            ma = pd.DataFrame(np.vstack(self.plot_data[var]).T).rolling(5, center=True)
+            mean = ma.mean().mean(axis=1).values.flatten()
+            upper = ma.max().mean(axis=1).values.flatten()
+            lower = ma.min().mean(axis=1).values.flatten()
+            ax.plot(self.f_index, mean, "blue")
+            ax.fill_between(self.f_index, lower, upper, color="lightblue")
+        ax.set_yscale("log")
+        ax.set_xscale("log")
+        ax.set_ylabel("power", fontsize='x-large')
+        ax.set_xlabel("frequency $[day^{-1}$]", fontsize='x-large')
+        lims = ax.get_ylim()
+        self._add_annotation_line([1, 2, 3], 365.25, lims, "yr")  # per year
+        self._add_annotation_line(1, 365.25 / 12, lims, "m")  # per month
+        self._add_annotation_line(1, 7, lims, "w")  # per week
+        self._add_annotation_line([1, 0.5], 1, lims, "d")  # per day
+        if self._sampling == "hourly":
+            self._add_annotation_line(2, 1, lims, "d")  # twice per day
+            self._add_annotation_line([1, 0.5], 1 / 24., lims, "h")  # per hour
+        ax.set_title(f"Periodogram ({var})")
+        return fig
diff --git a/mlair/run_modules/pre_processing.py b/mlair/run_modules/pre_processing.py
index 3c2670aa..d1ec0c60 100644
--- a/mlair/run_modules/pre_processing.py
+++ b/mlair/run_modules/pre_processing.py
@@ -18,7 +18,8 @@ from mlair.helpers import TimeTracking, to_list, tables
 from mlair.configuration import path_config
 from mlair.helpers.join import EmptyQueryResult
 from mlair.run_modules.run_environment import RunEnvironment
-from mlair.plotting.preprocessing_plotting import PlotStationMap, PlotAvailability, PlotAvailabilityHistogram
+from mlair.plotting.preprocessing_plotting import PlotStationMap, PlotAvailability, PlotAvailabilityHistogram, \
+    PlotPeriodogram
 
 
 class PreProcessing(RunEnvironment):
@@ -344,6 +345,8 @@ class PreProcessing(RunEnvironment):
         train_val_data = self.data_store.get("data_collection", "train_val")
         plot_path: str = self.data_store.get("plot_path")
 
+        sampling = self.data_store.get("sampling")
+
         try:
             if "PlotStationMap" in plot_list:
                 if self.data_store.get("hostname")[:2] in self.data_store.get("hpc_hosts") or self.data_store.get(
@@ -377,6 +380,14 @@ class PreProcessing(RunEnvironment):
         except Exception as e:
             logging.error(f"Could not create plot PlotAvailabilityHistogram due to the following error: {e}")
 
+        try:
+            if "PlotPeriodogram" in plot_list:
+                PlotPeriodogram(train_data, plot_folder=plot_path, time_dim=time_dim, variables_dim=target_dim,
+                                sampling=sampling)
+
+        except Exception as e:
+            logging.error(f"Could not create plot PlotPeriodogram due to the following error: {e}")
+
 
 def f_proc(data_handler, station, name_affix, store, **kwargs):
     """
diff --git a/requirements.txt b/requirements.txt
index 85655e23..dba565fb 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,7 @@
 absl-py==0.11.0
 appdirs==1.4.4
 astor==0.8.1
+astropy==4.1
 attrs==20.3.0
 bottleneck==1.3.2
 cached-property==1.5.2
diff --git a/requirements_gpu.txt b/requirements_gpu.txt
index cc189496..f170e1b7 100644
--- a/requirements_gpu.txt
+++ b/requirements_gpu.txt
@@ -1,6 +1,7 @@
 absl-py==0.11.0
 appdirs==1.4.4
 astor==0.8.1
+astropy==4.1
 attrs==20.3.0
 bottleneck==1.3.2
 cached-property==1.5.2
-- 
GitLab