From beebb5d222c62c76978561e27e5d4074b36b5225 Mon Sep 17 00:00:00 2001 From: Niklas Selke <n.selke@fz-juelich.de> Date: Tue, 6 Jun 2023 09:38:13 +0200 Subject: [PATCH] Added tests for the 'trends' subpackage. --- tests/test_trends/__init__.py | 0 tests/test_trends/conftest.py | 6 ++++ tests/test_trends/test_interface.py | 38 ++++++++++++++++++++ tests/test_trends/test_ols.py | 18 ++++++++++ tests/test_trends/test_quant_reg.py | 18 ++++++++++ tests/test_trends/test_utils.py | 56 +++++++++++++++++++++++++++++ 6 files changed, 136 insertions(+) create mode 100644 tests/test_trends/__init__.py create mode 100644 tests/test_trends/conftest.py create mode 100644 tests/test_trends/test_interface.py create mode 100644 tests/test_trends/test_ols.py create mode 100644 tests/test_trends/test_quant_reg.py create mode 100644 tests/test_trends/test_utils.py diff --git a/tests/test_trends/__init__.py b/tests/test_trends/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_trends/conftest.py b/tests/test_trends/conftest.py new file mode 100644 index 0000000..f6b5fab --- /dev/null +++ b/tests/test_trends/conftest.py @@ -0,0 +1,6 @@ +import pytest + + +@pytest.fixture(scope="session") +def mean_data(data): + return data.resample("MS").mean() diff --git a/tests/test_trends/test_interface.py b/tests/test_trends/test_interface.py new file mode 100644 index 0000000..777d06a --- /dev/null +++ b/tests/test_trends/test_interface.py @@ -0,0 +1,38 @@ +import pandas as pd +import pytest + +from toarstats.trends.interface import calculate_trend + + +def test_calculate_trend_wrong_method(mean_data): + with pytest.raises( + ValueError, match="abc is not recognized, must be 'OLS' or 'quant'" + ): + trend = calculate_trend("abc", mean_data) + + +def test_calculate_trend_wrong_quantile(mean_data): + with pytest.raises( + ValueError, match="The quantiles must be strictly between 0 and 1." + ): + trend = calculate_trend("quant", mean_data, 1.2) + + +@pytest.mark.filterwarnings("ignore:Maximum number of iterations") +@pytest.mark.slow +def test_calculate_trend_one_quantile(mean_data): + trend = calculate_trend("quant", mean_data, 0.5) + assert list(trend.keys()) == [0.5] + + +@pytest.mark.filterwarnings("ignore:Maximum number of iterations") +@pytest.mark.slow +def test_calculate_trend_multiple_quantiles(mean_data): + trends = calculate_trend("quant", mean_data, [0.25, 0.5, 0.75]) + assert list(trends.keys()) == [0.25, 0.5, 0.75] + + +@pytest.mark.slow +def test_calculate_trend_ols(mean_data): + trend = calculate_trend("OLS", mean_data) + assert list(trend.keys()) == ["trend", "uncertainty", "p_value"] diff --git a/tests/test_trends/test_ols.py b/tests/test_trends/test_ols.py new file mode 100644 index 0000000..4175a88 --- /dev/null +++ b/tests/test_trends/test_ols.py @@ -0,0 +1,18 @@ +import pandas as pd +import pytest + +from toarstats.trends.ols import ols + + +@pytest.mark.slow +def test_ols_sample_data(): + data = pd.DataFrame( + { + "datetime": [-2, 0, 1, 2, 4, 5, 6, 7, 10, 11], + "value": [ + 0.2, 2.1, -1.1, 0.01, 4.03, 2.37, -2.11, 1.96, -0.98, -1.32 + ] + } + ) + trend = ols(data) + assert list(trend.keys()) == ["trend", "uncertainty", "p_value"] diff --git a/tests/test_trends/test_quant_reg.py b/tests/test_trends/test_quant_reg.py new file mode 100644 index 0000000..7fc4781 --- /dev/null +++ b/tests/test_trends/test_quant_reg.py @@ -0,0 +1,18 @@ +import pandas as pd +import pytest + +from toarstats.trends.quant_reg import quant_reg + + +@pytest.mark.slow +def test_quant_reg_sample_data(): + data = pd.DataFrame( + { + "datetime": [-2, 0, 1, 2, 4, 5, 6, 7, 10, 11], + "value": [ + 0.2, 2.1, -1.1, 0.01, 4.03, 2.37, -2.11, 1.96, -0.98, -1.32 + ] + } + ) + trend = quant_reg(data, 0.5) + assert list(trend.keys()) == ["trend", "uncertainty", "p_value"] diff --git a/tests/test_trends/test_utils.py b/tests/test_trends/test_utils.py new file mode 100644 index 0000000..c3b1429 --- /dev/null +++ b/tests/test_trends/test_utils.py @@ -0,0 +1,56 @@ +import pytest + +from toarstats.trends.utils import ( + calculate_anomalies, calculate_seasonal_cycle, moving_block_bootstrap +) + + +def test_calculate_seasonal_cycle_monthly_data(mean_data): + seasonal_cycle = calculate_seasonal_cycle( + mean_data.rename(columns={"values": "value"}) + ) + assert seasonal_cycle.index.tolist() == list(range(12)) + + +def test_calculate_anomalies_monthly_data(mean_data): + anomalies_series = calculate_anomalies( + mean_data.rename(columns={"values": "value"}) + ) + assert anomalies_series.columns.tolist() == ["value", "datetime"] + assert anomalies_series["datetime"].tolist() == sorted( + anomalies_series["datetime"].tolist() + ) + + +@pytest.mark.slow +def test_moving_block_bootstrap_quant_reg(mean_data): + anomalies_series = calculate_anomalies( + mean_data.rename(columns={"values": "value"}) + ) + mbb = moving_block_bootstrap("quant_reg", anomalies_series, 0.5) + assert len(mbb) == 1000 + + +def test_moving_block_bootstrap_quant_reg_few_samples(mean_data): + anomalies_series = calculate_anomalies( + mean_data.rename(columns={"values": "value"}) + ) + mbb = moving_block_bootstrap("quant_reg", anomalies_series, 0.5, 50) + assert len(mbb) == 50 + + +@pytest.mark.slow +def test_moving_block_bootstrap_ols(mean_data): + anomalies_series = calculate_anomalies( + mean_data.rename(columns={"values": "value"}) + ) + mbb = moving_block_bootstrap("OLS", anomalies_series) + assert len(mbb) == 1000 + + +def test_moving_block_bootstrap_ols_few_samples(mean_data): + anomalies_series = calculate_anomalies( + mean_data.rename(columns={"values": "value"}) + ) + mbb = moving_block_bootstrap("OLS", anomalies_series, num_samples=50) + assert len(mbb) == 50 -- GitLab