From 00d65a751a5e2ecb3702485d73b45acc95144fae Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Tue, 17 Mar 2020 15:12:22 +0100
Subject: [PATCH] implementation of extremes in data preparation class

---
 src/data_handling/data_preparation.py      | 64 +++++++++++++++++++++-
 test/test_data_handling/test_bootstraps.py |  2 +-
 2 files changed, 64 insertions(+), 2 deletions(-)

diff --git a/src/data_handling/data_preparation.py b/src/data_handling/data_preparation.py
index e3186778..033859fa 100644
--- a/src/data_handling/data_preparation.py
+++ b/src/data_handling/data_preparation.py
@@ -5,7 +5,7 @@ import datetime as dt
 from functools import reduce
 import logging
 import os
-from typing import Union, List, Iterable
+from typing import Union, List, Iterable, Tuple
 
 import numpy as np
 import pandas as pd
@@ -17,6 +17,8 @@ from src import statistics
 # define a more general date type for type hinting
 date = Union[dt.date, dt.datetime]
 str_or_list = Union[str, List[str]]
+number = Union[float, int]
+num_or_list = Union[number, List[number]]
 
 
 class DataPrep(object):
@@ -58,6 +60,8 @@ class DataPrep(object):
         self.history = None
         self.label = None
         self.observation = None
+        self.extremes_history = None
+        self.extremes_labels = None
         self.kwargs = kwargs
         self.data = None
         self.meta = None
@@ -420,6 +424,64 @@ class DataPrep(object):
     def get_transposed_label(self):
         return self.label.squeeze("Stations").transpose("datetime", "window").copy()
 
+    def multiply_extremes(self, extreme_values: num_or_list = 1., extremes_on_right_tail_only: bool = False,
+                          timedelta: Tuple[int, str] = (1, 'm')):
+        """
+        Extract extreme samples from self.label and self.history and store marked copies of them.
+
+        Extraction is performed in standardised space. For each entry of extreme_values, all samples whose label
+        exceeds this entry (or falls below its negative, unless extremes_on_right_tail_only is set) are extracted.
+        The entries are processed in ascending order and each further entry is only checked against the samples
+        that were already extracted. If for example extreme_values = [1., 2.], a value of 1.5 is extracted once
+        (for the 0th entry), while a value of 2.5 is extracted twice (once for each entry). Each extraction shifts
+        the timestamps of the extracted copies by timedelta (default: one minute). As TOAR data are hourly, these
+        "artificial" data points can easily be identified later. Extreme inputs and labels are stored in
+        self.extremes_history and self.extremes_labels, respectively.
+
+        :param extreme_values: user definition of extreme, either a single number or a list of numbers
+        :param extremes_on_right_tail_only: if False also multiply values which are smaller than -extreme_values,
+            if True only extract values larger than extreme_values
+        :param timedelta: used as arguments for np.timedelta64 in order to mark extreme values on datetime
+        :raises TypeError: if an element of extreme_values is neither float nor int
+        """
+        # check type of inputs before sorting, so that a wrong type always results in the explicit error below
+        extreme_values = helpers.to_list(extreme_values)
+        for i in extreme_values:
+            if not isinstance(i, number.__args__):
+                raise TypeError(f"Elements of list extreme_values have to be {number.__args__}, but at least element "
+                                f"{i} is type {type(i)}")
+
+        # sorted() returns a new list, so the caller's list is never reordered in place
+        for extr_val in sorted(extreme_values):
+            # Compare against None explicitly: the truth value of an xarray object with more than one element is
+            # ambiguous and raises a ValueError, so all([...]) fails as soon as extremes have been stored.
+            if self.extremes_labels is None or self.extremes_history is None:
+                # no extremes extracted yet: select extremes based on occurrence in labels
+                extreme_mask = self.label > extr_val
+                if not extremes_on_right_tail_only:
+                    extreme_mask = extreme_mask | (self.label < -extr_val)
+                extreme_label_idx = extreme_mask.any(axis=0).values.reshape(-1,)
+                extremes_label = self.label[..., extreme_label_idx]
+                extremes_history = self.history[..., extreme_label_idx, :]
+                # shift the timestamps to mark the extracted copies as artificial data points
+                extremes_label.datetime.values += np.timedelta64(*timedelta)
+                extremes_history.datetime.values += np.timedelta64(*timedelta)
+                self.extremes_labels = extremes_label.squeeze('Stations').transpose('datetime', 'window')
+                self.extremes_history = extremes_history.transpose('datetime', 'window', 'Stations', 'variables')
+            else:  # extremes are already stored: self.extremes_labels and self.extremes_history are NOT None
+                # check on existing extracted extremes to minimise computational costs for comparison
+                extreme_mask = self.extremes_labels > extr_val
+                if not extremes_on_right_tail_only:
+                    extreme_mask = extreme_mask | (self.extremes_labels < -extr_val)
+                extreme_label_idx = extreme_mask.any(axis=1).values.reshape(-1,)
+                extremes_label = self.extremes_labels[extreme_label_idx, ...]
+                extremes_history = self.extremes_history[extreme_label_idx, ...]
+                # these copies were shifted once already, so they end up another timedelta later
+                extremes_label.datetime.values += np.timedelta64(*timedelta)
+                extremes_history.datetime.values += np.timedelta64(*timedelta)
+                self.extremes_labels = xr.concat([self.extremes_labels, extremes_label], dim='datetime')
+                self.extremes_history = xr.concat([self.extremes_history, extremes_history], dim='datetime')
+
 
 if __name__ == "__main__":
     dp = DataPrep('data/', 'dummy', 'DEBW107', ['o3', 'temp'], statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'})
diff --git a/test/test_data_handling/test_bootstraps.py b/test/test_data_handling/test_bootstraps.py
index 68789ff6..e7449952 100644
--- a/test/test_data_handling/test_bootstraps.py
+++ b/test/test_data_handling/test_bootstraps.py
@@ -52,7 +52,7 @@ class TestBootstraps:
         boot_no_init.number_bootstraps = 50
         assert boot_no_init.valid_bootstrap_file(station, variables, 20) == (False, 60)
 
-    def test_shuffle_single_variale(self, boot_no_init):
+    def test_shuffle_single_variable(self, boot_no_init):
         data = np.array([[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]])
         res = boot_no_init.shuffle_single_variable(data, chunks=(2, 3)).compute()
         assert res.shape == data.shape
-- 
GitLab