From 4aa8cba0dd463a615ea07bab26d2565b3973c79e Mon Sep 17 00:00:00 2001 From: Niklas Selke <n.selke@fz-juelich.de> Date: Tue, 6 Jun 2023 10:35:07 +0200 Subject: [PATCH] Added 'num_samples' parameter to the trends interface. --- tests/test_trends/test_interface.py | 25 +++++++++++++++++++++++++ tests/test_trends/test_ols.py | 15 ++++++++++++++- tests/test_trends/test_quant_reg.py | 15 ++++++++++++++- toarstats/trends/interface.py | 11 ++++++++--- toarstats/trends/ols.py | 6 ++++-- toarstats/trends/quant_reg.py | 6 ++++-- 6 files changed, 69 insertions(+), 9 deletions(-) diff --git a/tests/test_trends/test_interface.py b/tests/test_trends/test_interface.py index 777d06a..4600e6b 100644 --- a/tests/test_trends/test_interface.py +++ b/tests/test_trends/test_interface.py @@ -18,6 +18,14 @@ def test_calculate_trend_wrong_quantile(mean_data): trend = calculate_trend("quant", mean_data, 1.2) +def test_calculate_trend_wrong_num_samples(mean_data): + with pytest.raises( + ValueError, + match="The number of samples must be a positive integer." + ): + trend = calculate_trend("quant", mean_data, 0.5, -1) + + @pytest.mark.filterwarnings("ignore:Maximum number of iterations") @pytest.mark.slow def test_calculate_trend_one_quantile(mean_data): @@ -25,6 +33,12 @@ def test_calculate_trend_one_quantile(mean_data): assert list(trend.keys()) == [0.5] +@pytest.mark.filterwarnings("ignore:Maximum number of iterations") +def test_calculate_trend_one_quantile_few_samples(mean_data): + trend = calculate_trend("quant", mean_data, 0.5, 50) + assert list(trend.keys()) == [0.5] + + @pytest.mark.filterwarnings("ignore:Maximum number of iterations") @pytest.mark.slow def test_calculate_trend_multiple_quantiles(mean_data): @@ -32,7 +46,18 @@ def test_calculate_trend_multiple_quantiles(mean_data): assert list(trends.keys()) == [0.25, 0.5, 0.75] +@pytest.mark.filterwarnings("ignore:Maximum number of iterations") +def test_calculate_trend_multiple_quantiles_few_samples(mean_data): + trends = calculate_trend("quant", mean_data, [0.25, 0.5, 0.75], 50) + assert list(trends.keys()) == [0.25, 0.5, 0.75] + + @pytest.mark.slow def test_calculate_trend_ols(mean_data): trend = calculate_trend("OLS", mean_data) assert list(trend.keys()) == ["trend", "uncertainty", "p_value"] + + +def test_calculate_trend_ols_few_samples(mean_data): + trend = calculate_trend("OLS", mean_data, num_samples=50) + assert list(trend.keys()) == ["trend", "uncertainty", "p_value"] diff --git a/tests/test_trends/test_ols.py b/tests/test_trends/test_ols.py index 4175a88..294f03b 100644 --- a/tests/test_trends/test_ols.py +++ b/tests/test_trends/test_ols.py @@ -14,5 +14,18 @@ def test_ols_sample_data(): ] } ) - trend = ols(data) + trend = ols(data, 1000) + assert list(trend.keys()) == ["trend", "uncertainty", "p_value"] + + +def test_ols_sample_data_few_samples(): + data = pd.DataFrame( + { + "datetime": [-2, 0, 1, 2, 4, 5, 6, 7, 10, 11], + "value": [ + 0.2, 2.1, -1.1, 0.01, 4.03, 2.37, -2.11, 1.96, -0.98, -1.32 + ] + } + ) + trend = ols(data, 50) assert list(trend.keys()) == ["trend", "uncertainty", "p_value"] diff --git a/tests/test_trends/test_quant_reg.py b/tests/test_trends/test_quant_reg.py index 7fc4781..98c47a8 100644 --- a/tests/test_trends/test_quant_reg.py +++ b/tests/test_trends/test_quant_reg.py @@ -14,5 +14,18 @@ def test_quant_reg_sample_data(): ] } ) - trend = quant_reg(data, 0.5) + trend = quant_reg(data, 0.5, 1000) + assert list(trend.keys()) == ["trend", "uncertainty", "p_value"] + + +def test_quant_reg_sample_data_few_samples(): + data = pd.DataFrame( + { + "datetime": [-2, 0, 1, 2, 4, 5, 6, 7, 10, 11], + "value": [ + 0.2, 2.1, -1.1, 0.01, 4.03, 2.37, -2.11, 1.96, -0.98, -1.32 + ] + } + ) + trend = quant_reg(data, 0.5, 50) assert list(trend.keys()) == ["trend", "uncertainty", "p_value"] diff --git a/toarstats/trends/interface.py b/toarstats/trends/interface.py index 9998380..800c700 100644 --- a/toarstats/trends/interface.py +++ b/toarstats/trends/interface.py @@ -10,7 +10,7 @@ from toarstats.trends.quant_reg import quant_reg from toarstats.trends.utils import calculate_anomalies -def calculate_trend(method, data, quantiles=None): +def calculate_trend(method, data, quantiles=None, num_samples=1000): """Calculate the trend using the requested method. This function is the public interface for the ``trends`` subpackage. @@ -28,6 +28,8 @@ def calculate_trend(method, data, quantiles=None): :param quantiles: a single quantile or a list of quantiles to calculate, these must be between 0 and 1; only needed when ``method="quant"`` + :param num_samples: number of sampled trends in moving block + bootstrap :raises TypeError: raised if @@ -46,6 +48,7 @@ def calculate_trend(method, data, quantiles=None): - the index and values have different lengths - any ``quantiles`` are not strictly within 0 and 1 with ``method="quantreg"`` + - ``num_samples`` is not a positive integer :return: The result of the fit or a dict of fit results if ``method="quant"`` @@ -62,12 +65,14 @@ def calculate_trend(method, data, quantiles=None): ) if not all(0 < quantile < 1 for quantile in quantile_list): raise ValueError("The quantiles must be strictly between 0 and 1.") + if not isinstance(num_samples, int) or num_samples < 1: + raise ValueError("The number of samples must be a positive integer.") anomalies_series = calculate_anomalies(data_in) if method == "quant": fit = { - quantile: quant_reg(anomalies_series, quantile) + quantile: quant_reg(anomalies_series, quantile, num_samples) for quantile in quantile_list } else: - fit = ols(anomalies_series) + fit = ols(anomalies_series, num_samples) return fit diff --git a/toarstats/trends/ols.py b/toarstats/trends/ols.py index a34cf32..f598d07 100644 --- a/toarstats/trends/ols.py +++ b/toarstats/trends/ols.py @@ -11,17 +11,19 @@ import statsmodels.formula.api as smf from toarstats.trends.utils import moving_block_bootstrap -def ols(data): +def ols(data, num_samples): """Calculate the OLS linear regression. :param data: data containing a list of date time values and associated parameter values on which to calculate the trend + :param num_samples: number of sampled trends in moving block + bootstrap :return: The trend with its uncertainty and p value """ fit = smf.ols("value~datetime", data).fit(method="qr").params - mbb = moving_block_bootstrap("OLS", data) + mbb = moving_block_bootstrap("OLS", data, num_samples=num_samples) fit_se = np.nanstd(mbb, axis=0) fit_pv = 2*scipy.stats.t.sf(x=abs(fit/fit_se), df=len(data)-2) return {"trend": fit, "uncertainty": fit_se, "p_value": fit_pv} diff --git a/toarstats/trends/quant_reg.py b/toarstats/trends/quant_reg.py index dd2d926..dc9d9d9 100644 --- a/toarstats/trends/quant_reg.py +++ b/toarstats/trends/quant_reg.py @@ -11,18 +11,20 @@ import statsmodels.formula.api as smf from toarstats.trends.utils import moving_block_bootstrap -def quant_reg(data, quantile): +def quant_reg(data, quantile, num_samples): """Calculate the quantile regression. :param data: data containing a list of date time values and associated parameter values on which to calculate the trend :param quantile: a single quantile, must be between 0 and 1 + :param num_samples: number of sampled trends in moving block + bootstrap :return: The trend with its uncertainty and p value """ fit = smf.quantreg("value~datetime", data).fit(q=quantile).params - mbb = moving_block_bootstrap("quant", data, quantile) + mbb = moving_block_bootstrap("quant", data, quantile, num_samples) fit_se = np.nanstd(mbb, axis=0) fit_pv = 2*scipy.stats.t.sf(x=abs(fit/fit_se), df=len(data)-2) return {"trend": fit, "uncertainty": fit_se, "p_value": fit_pv} -- GitLab