diff --git a/tests/test_trends/test_interface.py b/tests/test_trends/test_interface.py index 777d06aaf08bd25e28998797cebbfa8bf1400363..4600e6bac0f7a760cd09db3ad3de188786854be2 100644 --- a/tests/test_trends/test_interface.py +++ b/tests/test_trends/test_interface.py @@ -18,6 +18,14 @@ def test_calculate_trend_wrong_quantile(mean_data): trend = calculate_trend("quant", mean_data, 1.2) +def test_calculate_trend_wrong_num_samples(mean_data): + with pytest.raises( + ValueError, + match="The number of samples must be a positive integer." + ): + trend = calculate_trend("quant", mean_data, 0.5, -1) + + @pytest.mark.filterwarnings("ignore:Maximum number of iterations") @pytest.mark.slow def test_calculate_trend_one_quantile(mean_data): @@ -25,6 +33,12 @@ def test_calculate_trend_one_quantile(mean_data): assert list(trend.keys()) == [0.5] +@pytest.mark.filterwarnings("ignore:Maximum number of iterations") +def test_calculate_trend_one_quantile_few_samples(mean_data): + trend = calculate_trend("quant", mean_data, 0.5, 50) + assert list(trend.keys()) == [0.5] + + @pytest.mark.filterwarnings("ignore:Maximum number of iterations") @pytest.mark.slow def test_calculate_trend_multiple_quantiles(mean_data): @@ -32,7 +46,18 @@ def test_calculate_trend_multiple_quantiles(mean_data): assert list(trends.keys()) == [0.25, 0.5, 0.75] +@pytest.mark.filterwarnings("ignore:Maximum number of iterations") +def test_calculate_trend_multiple_quantiles_few_samples(mean_data): + trends = calculate_trend("quant", mean_data, [0.25, 0.5, 0.75], 50) + assert list(trends.keys()) == [0.25, 0.5, 0.75] + + @pytest.mark.slow def test_calculate_trend_ols(mean_data): trend = calculate_trend("OLS", mean_data) assert list(trend.keys()) == ["trend", "uncertainty", "p_value"] + + +def test_calculate_trend_ols_few_samples(mean_data): + trend = calculate_trend("OLS", mean_data, num_samples=50) + assert list(trend.keys()) == ["trend", "uncertainty", "p_value"] diff --git a/tests/test_trends/test_ols.py b/tests/test_trends/test_ols.py index 4175a88fbfd85db2cb1d9db3e3ae6cf128eddc6e..294f03bad5e3664dcb5606c0be7843e4c6ba580f 100644 --- a/tests/test_trends/test_ols.py +++ b/tests/test_trends/test_ols.py @@ -14,5 +14,18 @@ def test_ols_sample_data(): ] } ) - trend = ols(data) + trend = ols(data, 1000) + assert list(trend.keys()) == ["trend", "uncertainty", "p_value"] + + +def test_ols_sample_data_few_samples(): + data = pd.DataFrame( + { + "datetime": [-2, 0, 1, 2, 4, 5, 6, 7, 10, 11], + "value": [ + 0.2, 2.1, -1.1, 0.01, 4.03, 2.37, -2.11, 1.96, -0.98, -1.32 + ] + } + ) + trend = ols(data, 50) assert list(trend.keys()) == ["trend", "uncertainty", "p_value"] diff --git a/tests/test_trends/test_quant_reg.py b/tests/test_trends/test_quant_reg.py index 7fc47811daf59b98539c117cf77f0746ce4dbfa0..98c47a8ab545b02694608213cc4b203b6b0b519a 100644 --- a/tests/test_trends/test_quant_reg.py +++ b/tests/test_trends/test_quant_reg.py @@ -14,5 +14,18 @@ def test_quant_reg_sample_data(): ] } ) - trend = quant_reg(data, 0.5) + trend = quant_reg(data, 0.5, 1000) + assert list(trend.keys()) == ["trend", "uncertainty", "p_value"] + + +def test_quant_reg_sample_data_few_samples(): + data = pd.DataFrame( + { + "datetime": [-2, 0, 1, 2, 4, 5, 6, 7, 10, 11], + "value": [ + 0.2, 2.1, -1.1, 0.01, 4.03, 2.37, -2.11, 1.96, -0.98, -1.32 + ] + } + ) + trend = quant_reg(data, 0.5, 50) assert list(trend.keys()) == ["trend", "uncertainty", "p_value"] diff --git a/toarstats/trends/interface.py b/toarstats/trends/interface.py index 9998380f71cdd563de724b6cac1a5801e949ed9a..800c700d02d08c0a99b8355f5e5cf32847ea8a7d 100644 --- a/toarstats/trends/interface.py +++ b/toarstats/trends/interface.py @@ -10,7 +10,7 @@ from toarstats.trends.quant_reg import quant_reg from toarstats.trends.utils import calculate_anomalies -def calculate_trend(method, data, quantiles=None): +def calculate_trend(method, data, quantiles=None, num_samples=1000): """Calculate the trend using the requested method. This function is the public interface for the ``trends`` subpackage. @@ -28,6 +28,8 @@ def calculate_trend(method, data, quantiles=None): :param quantiles: a single quantile or a list of quantiles to calculate, these must be between 0 and 1; only needed when ``method="quant"`` + :param num_samples: number of sampled trends in moving block + bootstrap :raises TypeError: raised if @@ -46,6 +48,7 @@ def calculate_trend(method, data, quantiles=None): - the index and values have different lengths - any ``quantiles`` are not strictly within 0 and 1 with ``method="quantreg"`` + - ``num_samples`` is not a positive integer :return: The result of the fit or a dict of fit results if ``method="quant"`` @@ -62,12 +65,14 @@ def calculate_trend(method, data, quantiles=None): ) if not all(0 < quantile < 1 for quantile in quantile_list): raise ValueError("The quantiles must be strictly between 0 and 1.") + if not isinstance(num_samples, int) or num_samples < 1: + raise ValueError("The number of samples must be a positive integer.") anomalies_series = calculate_anomalies(data_in) if method == "quant": fit = { - quantile: quant_reg(anomalies_series, quantile) + quantile: quant_reg(anomalies_series, quantile, num_samples) for quantile in quantile_list } else: - fit = ols(anomalies_series) + fit = ols(anomalies_series, num_samples) return fit diff --git a/toarstats/trends/ols.py b/toarstats/trends/ols.py index a34cf32f1323e3b07fa6849f4be9169ee6f4d8f9..f598d07568345e8b67921e556a166c9d1ab1d6f8 100644 --- a/toarstats/trends/ols.py +++ b/toarstats/trends/ols.py @@ -11,17 +11,19 @@ import statsmodels.formula.api as smf from toarstats.trends.utils import moving_block_bootstrap -def ols(data): +def ols(data, num_samples): """Calculate the OLS linear regression. :param data: data containing a list of date time values and associated parameter values on which to calculate the trend + :param num_samples: number of sampled trends in moving block + bootstrap :return: The trend with its uncertainty and p value """ fit = smf.ols("value~datetime", data).fit(method="qr").params - mbb = moving_block_bootstrap("OLS", data) + mbb = moving_block_bootstrap("OLS", data, num_samples=num_samples) fit_se = np.nanstd(mbb, axis=0) fit_pv = 2*scipy.stats.t.sf(x=abs(fit/fit_se), df=len(data)-2) return {"trend": fit, "uncertainty": fit_se, "p_value": fit_pv} diff --git a/toarstats/trends/quant_reg.py b/toarstats/trends/quant_reg.py index dd2d926d3c2911912005ad3ffabceac6577740ee..dc9d9d94a23643356e7abe49bc6613b76fe76075 100644 --- a/toarstats/trends/quant_reg.py +++ b/toarstats/trends/quant_reg.py @@ -11,18 +11,20 @@ import statsmodels.formula.api as smf from toarstats.trends.utils import moving_block_bootstrap -def quant_reg(data, quantile): +def quant_reg(data, quantile, num_samples): """Calculate the quantile regression. :param data: data containing a list of date time values and associated parameter values on which to calculate the trend :param quantile: a single quantile, must be between 0 and 1 + :param num_samples: number of sampled trends in moving block + bootstrap :return: The trend with its uncertainty and p value """ fit = smf.quantreg("value~datetime", data).fit(q=quantile).params - mbb = moving_block_bootstrap("quant", data, quantile) + mbb = moving_block_bootstrap("quant", data, quantile, num_samples) fit_se = np.nanstd(mbb, axis=0) fit_pv = 2*scipy.stats.t.sf(x=abs(fit/fit_se), df=len(data)-2) return {"trend": fit, "uncertainty": fit_se, "p_value": fit_pv}