diff --git a/tests/test_trends/__init__.py b/tests/test_trends/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tests/test_trends/conftest.py b/tests/test_trends/conftest.py new file mode 100644 index 0000000000000000000000000000000000000000..f6b5fab0a2ca95a7f0469a7477149cb7449f0490 --- /dev/null +++ b/tests/test_trends/conftest.py @@ -0,0 +1,6 @@ +import pytest + + +@pytest.fixture(scope="session") +def mean_data(data): + return data.resample("MS").mean() diff --git a/tests/test_trends/test_interface.py b/tests/test_trends/test_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..777d06aaf08bd25e28998797cebbfa8bf1400363 --- /dev/null +++ b/tests/test_trends/test_interface.py @@ -0,0 +1,38 @@ +import pandas as pd +import pytest + +from toarstats.trends.interface import calculate_trend + + +def test_calculate_trend_wrong_method(mean_data): + with pytest.raises( + ValueError, match="abc is not recognized, must be 'OLS' or 'quant'" + ): + trend = calculate_trend("abc", mean_data) + + +def test_calculate_trend_wrong_quantile(mean_data): + with pytest.raises( + ValueError, match="The quantiles must be strictly between 0 and 1." + ): + trend = calculate_trend("quant", mean_data, 1.2) + + +@pytest.mark.filterwarnings("ignore:Maximum number of iterations") +@pytest.mark.slow +def test_calculate_trend_one_quantile(mean_data): + trend = calculate_trend("quant", mean_data, 0.5) + assert list(trend.keys()) == [0.5] + + +@pytest.mark.filterwarnings("ignore:Maximum number of iterations") +@pytest.mark.slow +def test_calculate_trend_multiple_quantiles(mean_data): + trends = calculate_trend("quant", mean_data, [0.25, 0.5, 0.75]) + assert list(trends.keys()) == [0.25, 0.5, 0.75] + + +@pytest.mark.slow +def test_calculate_trend_ols(mean_data): + trend = calculate_trend("OLS", mean_data) + assert list(trend.keys()) == ["trend", "uncertainty", "p_value"] diff --git a/tests/test_trends/test_ols.py b/tests/test_trends/test_ols.py new file mode 100644 index 0000000000000000000000000000000000000000..4175a88fbfd85db2cb1d9db3e3ae6cf128eddc6e --- /dev/null +++ b/tests/test_trends/test_ols.py @@ -0,0 +1,18 @@ +import pandas as pd +import pytest + +from toarstats.trends.ols import ols + + +@pytest.mark.slow +def test_ols_sample_data(): + data = pd.DataFrame( + { + "datetime": [-2, 0, 1, 2, 4, 5, 6, 7, 10, 11], + "value": [ + 0.2, 2.1, -1.1, 0.01, 4.03, 2.37, -2.11, 1.96, -0.98, -1.32 + ] + } + ) + trend = ols(data) + assert list(trend.keys()) == ["trend", "uncertainty", "p_value"] diff --git a/tests/test_trends/test_quant_reg.py b/tests/test_trends/test_quant_reg.py new file mode 100644 index 0000000000000000000000000000000000000000..7fc47811daf59b98539c117cf77f0746ce4dbfa0 --- /dev/null +++ b/tests/test_trends/test_quant_reg.py @@ -0,0 +1,18 @@ +import pandas as pd +import pytest + +from toarstats.trends.quant_reg import quant_reg + + +@pytest.mark.slow +def test_quant_reg_sample_data(): + data = pd.DataFrame( + { + "datetime": [-2, 0, 1, 2, 4, 5, 6, 7, 10, 11], + "value": [ + 0.2, 2.1, -1.1, 0.01, 4.03, 2.37, -2.11, 1.96, -0.98, -1.32 + ] + } + ) + trend = quant_reg(data, 0.5) + assert list(trend.keys()) == ["trend", "uncertainty", "p_value"] diff --git a/tests/test_trends/test_utils.py b/tests/test_trends/test_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c3b142967e650a0deaa66691fd9999d68228b61c --- /dev/null +++ b/tests/test_trends/test_utils.py @@ -0,0 +1,56 @@ +import pytest + +from toarstats.trends.utils import ( + calculate_anomalies, calculate_seasonal_cycle, moving_block_bootstrap +) + + +def test_calculate_seasonal_cycle_monthly_data(mean_data): + seasonal_cycle = calculate_seasonal_cycle( + mean_data.rename(columns={"values": "value"}) + ) + assert seasonal_cycle.index.tolist() == list(range(12)) + + +def test_calculate_anomalies_monthly_data(mean_data): + anomalies_series = calculate_anomalies( + mean_data.rename(columns={"values": "value"}) + ) + assert anomalies_series.columns.tolist() == ["value", "datetime"] + assert anomalies_series["datetime"].tolist() == sorted( + anomalies_series["datetime"].tolist() + ) + + +@pytest.mark.slow +def test_moving_block_bootstrap_quant_reg(mean_data): + anomalies_series = calculate_anomalies( + mean_data.rename(columns={"values": "value"}) + ) + mbb = moving_block_bootstrap("quant_reg", anomalies_series, 0.5) + assert len(mbb) == 1000 + + +def test_moving_block_bootstrap_quant_reg_few_samples(mean_data): + anomalies_series = calculate_anomalies( + mean_data.rename(columns={"values": "value"}) + ) + mbb = moving_block_bootstrap("quant_reg", anomalies_series, 0.5, 50) + assert len(mbb) == 50 + + +@pytest.mark.slow +def test_moving_block_bootstrap_ols(mean_data): + anomalies_series = calculate_anomalies( + mean_data.rename(columns={"values": "value"}) + ) + mbb = moving_block_bootstrap("OLS", anomalies_series) + assert len(mbb) == 1000 + + +def test_moving_block_bootstrap_ols_few_samples(mean_data): + anomalies_series = calculate_anomalies( + mean_data.rename(columns={"values": "value"}) + ) + mbb = moving_block_bootstrap("OLS", anomalies_series, num_samples=50) + assert len(mbb) == 50