From 23e6524233d038c54453cc6daf5a244185d7af2d Mon Sep 17 00:00:00 2001
From: leufen1 <l.leufen@fz-juelich.de>
Date: Thu, 18 Feb 2021 12:52:25 +0100
Subject: [PATCH] box cox is not properly working, a log transformation is
 suitable enough

---
 .../data_handler_single_station.py            | 14 +++--
 mlair/helpers/statistics.py                   | 17 ++++++
 test/test_helpers/test_statistics.py          | 54 ++++++++++++++++++-
 3 files changed, 80 insertions(+), 5 deletions(-)

diff --git a/mlair/data_handler/data_handler_single_station.py b/mlair/data_handler/data_handler_single_station.py
index 4002d478..a894c635 100644
--- a/mlair/data_handler/data_handler_single_station.py
+++ b/mlair/data_handler/data_handler_single_station.py
@@ -178,6 +178,8 @@ class DataHandlerSingleStation(AbstractDataHandler):
                 return statistics.centre(data, dim)
             elif method == "min_max":
                 return statistics.min_max(data, dim)
+            elif method == "log":
+                return statistics.log(data, dim)
             else:
                 raise NotImplementedError
 
@@ -188,6 +190,8 @@ class DataHandlerSingleStation(AbstractDataHandler):
                 return statistics.centre_apply(data, mean), {"mean": mean, "method": method}
             elif method == "min_max":
                 return statistics.min_max_apply(data, min, max), {"min": min, "max": max, "method": method}
+            elif method == "log":
+                return statistics.log_apply(data, mean, std), {"mean": mean, "std": std, "method": method}
             else:
                 raise NotImplementedError
 
@@ -601,12 +605,14 @@ class DataHandlerSingleStation(AbstractDataHandler):
         """
 
         def f_inverse(data, method, mean=None, std=None, min=None, max=None):
-            if method == 'standardise':
+            if method == "standardise":
                 return statistics.standardise_inverse(data, mean, std)
-            elif method == 'centre':
+            elif method == "centre":
                 return statistics.centre_inverse(data, mean)
-            elif method == 'min_max':
-                raise statistics.min_max_inverse(data, min, max)
+            elif method == "min_max":
+                return statistics.min_max_inverse(data, min, max)
+            elif method == "log":
+                return statistics.log_inverse(data, mean, std)
             else:
                 raise NotImplementedError
 
diff --git a/mlair/helpers/statistics.py b/mlair/helpers/statistics.py
index 65574f4c..0a25644c 100644
--- a/mlair/helpers/statistics.py
+++ b/mlair/helpers/statistics.py
@@ -37,6 +37,8 @@ def apply_inverse_transformation(data: Data, method: str = "standardise", mean:
         return centre_inverse(data, mean)
     elif method == 'min_max':  # pragma: no branch
         return min_max_inverse(data, min, max)
+    elif method == "log":
+        return log_inverse(data, mean, std)
     else:
         raise NotImplementedError
 
@@ -152,6 +154,21 @@ def min_max_apply(data: Data, min: Data, max: Data) -> Data:
     return (data - min) / (max - min)
 
 
+def log(data: Data, dim: Union[str, int]) -> Tuple[Data, Dict[(str, Data)]]:
+    transformed_standardized, opts = standardise(np.log1p(data), dim)
+    opts.update({"method": "log"})
+    return transformed_standardized, opts
+
+
+def log_apply(data: Data, mean: Data, std: Data) -> Data:
+    return standardise_apply(np.log1p(data), mean, std)
+
+
+def log_inverse(data: Data, mean: Data, std: Data) -> Data:
+    data_rescaled = standardise_inverse(data, mean, std)
+    return np.expm1(data_rescaled)
+
+
 def mean_squared_error(a, b):
     """Calculate mean squared error."""
     return np.square(a - b).mean()
diff --git a/test/test_helpers/test_statistics.py b/test/test_helpers/test_statistics.py
index e0febe1e..2a77f080 100644
--- a/test/test_helpers/test_statistics.py
+++ b/test/test_helpers/test_statistics.py
@@ -4,7 +4,7 @@ import pytest
 import xarray as xr
 
 from mlair.helpers.statistics import standardise, standardise_inverse, standardise_apply, centre, centre_inverse, \
-    centre_apply, apply_inverse_transformation, min_max, min_max_inverse, min_max_apply
+    centre_apply, apply_inverse_transformation, min_max, min_max_inverse, min_max_apply, log, log_inverse, log_apply
 
 lazy = pytest.lazy_fixture
 
@@ -16,11 +16,23 @@ def input_data():
                      np.random.normal(10, 1, 3000)]).T
 
 
+@pytest.fixture(scope='module')
+def input_data_gamma():
+    return np.array([np.random.gamma(2, 2, 3000),
+                     np.random.gamma(3, 3, 3000),
+                     np.random.gamma(1, 1, 3000)]).T
+
+
 @pytest.fixture(scope='module')
 def pandas(input_data):
     return pd.DataFrame(input_data)
 
 
+@pytest.fixture(scope='module')
+def pandas_gamma(input_data_gamma):
+    return pd.DataFrame(input_data_gamma)
+
+
 @pytest.fixture(scope='module')
 def pd_mean():
     return [2, 10, 3]
@@ -48,6 +60,13 @@ def xarray(input_data):
     return xr.DataArray(input_data, coords=coords, dims=coords.keys())
 
 
+@pytest.fixture(scope='module')
+def xarray_gamma(input_data_gamma):
+    shape = input_data_gamma.shape
+    coords = {'index': range(shape[0]), 'value': range(shape[1])}
+    return xr.DataArray(input_data_gamma, coords=coords, dims=coords.keys())
+
+
 @pytest.fixture(scope='module')
 def xr_mean(input_data):
     return xr.DataArray([2, 10, 3], coords={'value': range(3)}, dims=['value'])
@@ -169,3 +188,36 @@ class TestMinMax:
         max_expected = (data_orig.max(dim) - dmin) / (dmax - dmin)
         assert np.testing.assert_array_almost_equal(data.min(dim), min_expected, decimal=1) is None
         assert np.testing.assert_array_almost_equal(data.max(dim), max_expected, decimal=1) is None
+
+
+class TestLog:
+
+    @pytest.mark.parametrize('data_orig, dim', [(lazy('pandas_gamma'), 0),
+                                                (lazy('xarray_gamma'), 'index')])
+    def test_standardise(self, data_orig, dim):
+        data, opts = log(data_orig, dim)
+        assert {"method", "mean", "std"} == opts.keys()
+        assert opts["method"] == "log"
+        assert np.testing.assert_almost_equal(data.mean(dim), [0, 0, 0]) is None
+        assert np.testing.assert_almost_equal(data.std(dim), [1, 1, 1]) is None
+
+    @pytest.mark.parametrize('data_orig, dim', [(lazy('pandas_gamma'), 0),
+                                                (lazy('xarray_gamma'), 'index')])
+    def test_standardise_inverse(self, data_orig, dim):
+        data, opts = log(data_orig, dim)
+        data_recovered = log_inverse(data, opts["mean"], opts["std"])
+        assert np.testing.assert_array_almost_equal(data_orig, data_recovered) is None
+
+    @pytest.mark.parametrize('data_orig, dim', [(lazy('pandas_gamma'), 0),
+                                                (lazy('xarray_gamma'), 'index')])
+    def test_apply_standardise_inverse(self, data_orig, dim):
+        data, opts = log(data_orig, dim)
+        data_recovered = apply_inverse_transformation(data, **opts)
+        assert np.testing.assert_array_almost_equal(data_orig, data_recovered) is None
+
+    @pytest.mark.parametrize('data_orig, dim', [(lazy('pandas'), 0),
+                                                (lazy('xarray'), 'index')])
+    def test_standardise_apply(self, data_orig, dim):
+        data_ref, opts = log(data_orig, dim)
+        data_test = log_apply(data_orig, opts["mean"], opts["std"])
+        assert np.testing.assert_array_almost_equal(data_ref, data_test) is None
-- 
GitLab