From beebb5d222c62c76978561e27e5d4074b36b5225 Mon Sep 17 00:00:00 2001
From: Niklas Selke <n.selke@fz-juelich.de>
Date: Tue, 6 Jun 2023 09:38:13 +0200
Subject: [PATCH] Added tests for the 'trends' subpackage.

---
 tests/test_trends/__init__.py       |  0
 tests/test_trends/conftest.py       |  6 ++++
 tests/test_trends/test_interface.py | 38 ++++++++++++++++++++
 tests/test_trends/test_ols.py       | 18 ++++++++++
 tests/test_trends/test_quant_reg.py | 18 ++++++++++
 tests/test_trends/test_utils.py     | 56 +++++++++++++++++++++++++++++
 6 files changed, 136 insertions(+)
 create mode 100644 tests/test_trends/__init__.py
 create mode 100644 tests/test_trends/conftest.py
 create mode 100644 tests/test_trends/test_interface.py
 create mode 100644 tests/test_trends/test_ols.py
 create mode 100644 tests/test_trends/test_quant_reg.py
 create mode 100644 tests/test_trends/test_utils.py

diff --git a/tests/test_trends/__init__.py b/tests/test_trends/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/test_trends/conftest.py b/tests/test_trends/conftest.py
new file mode 100644
index 0000000..f6b5fab
--- /dev/null
+++ b/tests/test_trends/conftest.py
@@ -0,0 +1,6 @@
+import pytest
+
+
+@pytest.fixture(scope="session")
+def mean_data(data):
+    return data.resample("MS").mean()
diff --git a/tests/test_trends/test_interface.py b/tests/test_trends/test_interface.py
new file mode 100644
index 0000000..777d06a
--- /dev/null
+++ b/tests/test_trends/test_interface.py
@@ -0,0 +1,38 @@
+import pandas as pd
+import pytest
+
+from toarstats.trends.interface import calculate_trend
+
+
+def test_calculate_trend_wrong_method(mean_data):
+    with pytest.raises(
+            ValueError, match="abc is not recognized, must be 'OLS' or 'quant'"
+    ):
+        trend = calculate_trend("abc", mean_data)
+
+
+def test_calculate_trend_wrong_quantile(mean_data):
+    with pytest.raises(
+            ValueError, match="The quantiles must be strictly between 0 and 1."
+    ):
+        trend = calculate_trend("quant", mean_data, 1.2)
+
+
+@pytest.mark.filterwarnings("ignore:Maximum number of iterations")
+@pytest.mark.slow
+def test_calculate_trend_one_quantile(mean_data):
+    trend = calculate_trend("quant", mean_data, 0.5)
+    assert list(trend.keys()) == [0.5]
+
+
+@pytest.mark.filterwarnings("ignore:Maximum number of iterations")
+@pytest.mark.slow
+def test_calculate_trend_multiple_quantiles(mean_data):
+    trends = calculate_trend("quant", mean_data, [0.25, 0.5, 0.75])
+    assert list(trends.keys()) == [0.25, 0.5, 0.75]
+
+
+@pytest.mark.slow
+def test_calculate_trend_ols(mean_data):
+    trend = calculate_trend("OLS", mean_data)
+    assert list(trend.keys()) == ["trend", "uncertainty", "p_value"]
diff --git a/tests/test_trends/test_ols.py b/tests/test_trends/test_ols.py
new file mode 100644
index 0000000..4175a88
--- /dev/null
+++ b/tests/test_trends/test_ols.py
@@ -0,0 +1,18 @@
+import pandas as pd
+import pytest
+
+from toarstats.trends.ols import ols
+
+
+@pytest.mark.slow
+def test_ols_sample_data():
+    data = pd.DataFrame(
+        {
+            "datetime": [-2, 0, 1, 2, 4, 5, 6, 7, 10, 11],
+            "value": [
+                0.2, 2.1, -1.1, 0.01, 4.03, 2.37, -2.11, 1.96, -0.98, -1.32
+            ]
+        }
+    )
+    trend = ols(data)
+    assert list(trend.keys()) == ["trend", "uncertainty", "p_value"]
diff --git a/tests/test_trends/test_quant_reg.py b/tests/test_trends/test_quant_reg.py
new file mode 100644
index 0000000..7fc4781
--- /dev/null
+++ b/tests/test_trends/test_quant_reg.py
@@ -0,0 +1,18 @@
+import pandas as pd
+import pytest
+
+from toarstats.trends.quant_reg import quant_reg
+
+
+@pytest.mark.slow
+def test_quant_reg_sample_data():
+    data = pd.DataFrame(
+        {
+            "datetime": [-2, 0, 1, 2, 4, 5, 6, 7, 10, 11],
+            "value": [
+                0.2, 2.1, -1.1, 0.01, 4.03, 2.37, -2.11, 1.96, -0.98, -1.32
+            ]
+        }
+    )
+    trend = quant_reg(data, 0.5)
+    assert list(trend.keys()) == ["trend", "uncertainty", "p_value"]
diff --git a/tests/test_trends/test_utils.py b/tests/test_trends/test_utils.py
new file mode 100644
index 0000000..c3b1429
--- /dev/null
+++ b/tests/test_trends/test_utils.py
@@ -0,0 +1,56 @@
+import pytest
+
+from toarstats.trends.utils import (
+    calculate_anomalies, calculate_seasonal_cycle, moving_block_bootstrap
+)
+
+
+def test_calculate_seasonal_cycle_monthly_data(mean_data):
+    seasonal_cycle = calculate_seasonal_cycle(
+        mean_data.rename(columns={"values": "value"})
+    )
+    assert seasonal_cycle.index.tolist() == list(range(12))
+
+
+def test_calculate_anomalies_monthly_data(mean_data):
+    anomalies_series = calculate_anomalies(
+        mean_data.rename(columns={"values": "value"})
+    )
+    assert anomalies_series.columns.tolist() == ["value", "datetime"]
+    assert anomalies_series["datetime"].tolist() == sorted(
+        anomalies_series["datetime"].tolist()
+    )
+
+
+@pytest.mark.slow
+def test_moving_block_bootstrap_quant_reg(mean_data):
+    anomalies_series = calculate_anomalies(
+        mean_data.rename(columns={"values": "value"})
+    )
+    mbb = moving_block_bootstrap("quant_reg", anomalies_series, 0.5)
+    assert len(mbb) == 1000
+
+
+def test_moving_block_bootstrap_quant_reg_few_samples(mean_data):
+    anomalies_series = calculate_anomalies(
+        mean_data.rename(columns={"values": "value"})
+    )
+    mbb = moving_block_bootstrap("quant_reg", anomalies_series, 0.5, 50)
+    assert len(mbb) == 50
+
+
+@pytest.mark.slow
+def test_moving_block_bootstrap_ols(mean_data):
+    anomalies_series = calculate_anomalies(
+        mean_data.rename(columns={"values": "value"})
+    )
+    mbb = moving_block_bootstrap("OLS", anomalies_series)
+    assert len(mbb) == 1000
+
+
+def test_moving_block_bootstrap_ols_few_samples(mean_data):
+    anomalies_series = calculate_anomalies(
+        mean_data.rename(columns={"values": "value"})
+    )
+    mbb = moving_block_bootstrap("OLS", anomalies_series, num_samples=50)
+    assert len(mbb) == 50
-- 
GitLab