Skip to content
Snippets Groups Projects
Commit 3b9ee13f authored by Niklas Selke
Browse files

Merge branch 'niklas_issue009_feat_calculate-quantile-regression' into 'develop'

Added 'num_samples' parameter to the trends interface.

See merge request !8
parents 514e9f7c 4aa8cba0
No related branches found
No related tags found
2 merge requests: !9 Develop, !8 Added 'num_samples' parameter to the trends interface.
...@@ -18,6 +18,14 @@ def test_calculate_trend_wrong_quantile(mean_data): ...@@ -18,6 +18,14 @@ def test_calculate_trend_wrong_quantile(mean_data):
trend = calculate_trend("quant", mean_data, 1.2) trend = calculate_trend("quant", mean_data, 1.2)
def test_calculate_trend_wrong_num_samples(mean_data):
    """Check that a negative ``num_samples`` is rejected.

    ``calculate_trend`` is called with ``num_samples=-1`` and must raise
    a ``ValueError`` carrying the positive-integer message.

    :param mean_data: fixture providing the input data for the trend
                      calculation
    """
    with pytest.raises(
        ValueError,
        # ``match`` is applied as a regular expression, so the trailing
        # period is escaped to match a literal "." only.
        match=r"The number of samples must be a positive integer\."
    ):
        # Only the raised exception matters; the result is not bound.
        calculate_trend("quant", mean_data, 0.5, -1)
@pytest.mark.filterwarnings("ignore:Maximum number of iterations") @pytest.mark.filterwarnings("ignore:Maximum number of iterations")
@pytest.mark.slow @pytest.mark.slow
def test_calculate_trend_one_quantile(mean_data): def test_calculate_trend_one_quantile(mean_data):
...@@ -25,6 +33,12 @@ def test_calculate_trend_one_quantile(mean_data): ...@@ -25,6 +33,12 @@ def test_calculate_trend_one_quantile(mean_data):
assert list(trend.keys()) == [0.5] assert list(trend.keys()) == [0.5]
@pytest.mark.filterwarnings("ignore:Maximum number of iterations")
def test_calculate_trend_one_quantile_few_samples(mean_data):
    """Check the result keys for one quantile with 50 samples.

    The quantile trend is requested for ``0.5`` with ``num_samples=50``
    (passed positionally) instead of the default of 1000; the returned
    mapping must contain exactly that quantile as its key.

    :param mean_data: fixture providing the input data for the trend
                      calculation
    """
    expected_keys = [0.5]
    fit_by_quantile = calculate_trend("quant", mean_data, 0.5, 50)
    assert list(fit_by_quantile) == expected_keys
@pytest.mark.filterwarnings("ignore:Maximum number of iterations") @pytest.mark.filterwarnings("ignore:Maximum number of iterations")
@pytest.mark.slow @pytest.mark.slow
def test_calculate_trend_multiple_quantiles(mean_data): def test_calculate_trend_multiple_quantiles(mean_data):
...@@ -32,7 +46,18 @@ def test_calculate_trend_multiple_quantiles(mean_data): ...@@ -32,7 +46,18 @@ def test_calculate_trend_multiple_quantiles(mean_data):
assert list(trends.keys()) == [0.25, 0.5, 0.75] assert list(trends.keys()) == [0.25, 0.5, 0.75]
@pytest.mark.filterwarnings("ignore:Maximum number of iterations")
def test_calculate_trend_multiple_quantiles_few_samples(mean_data):
    """Check the result keys for several quantiles with 50 samples.

    The quantile trends are requested for ``0.25``, ``0.5`` and ``0.75``
    with ``num_samples=50`` (passed positionally); the returned mapping
    must list those quantiles as keys in the requested order.

    :param mean_data: fixture providing the input data for the trend
                      calculation
    """
    expected_keys = [0.25, 0.5, 0.75]
    fits_by_quantile = calculate_trend(
        "quant", mean_data, [0.25, 0.5, 0.75], 50
    )
    assert list(fits_by_quantile) == expected_keys
@pytest.mark.slow @pytest.mark.slow
def test_calculate_trend_ols(mean_data): def test_calculate_trend_ols(mean_data):
trend = calculate_trend("OLS", mean_data) trend = calculate_trend("OLS", mean_data)
assert list(trend.keys()) == ["trend", "uncertainty", "p_value"] assert list(trend.keys()) == ["trend", "uncertainty", "p_value"]
def test_calculate_trend_ols_few_samples(mean_data):
    """Check the result keys of the OLS trend with 50 samples.

    The OLS trend is requested with ``num_samples=50`` instead of the
    default of 1000; the result must expose the keys ``trend``,
    ``uncertainty`` and ``p_value`` in that order.

    :param mean_data: fixture providing the input data for the trend
                      calculation
    """
    expected_keys = ["trend", "uncertainty", "p_value"]
    ols_fit = calculate_trend("OLS", mean_data, num_samples=50)
    assert list(ols_fit) == expected_keys
...@@ -14,5 +14,18 @@ def test_ols_sample_data(): ...@@ -14,5 +14,18 @@ def test_ols_sample_data():
] ]
} }
) )
trend = ols(data) trend = ols(data, 1000)
assert list(trend.keys()) == ["trend", "uncertainty", "p_value"]
def test_ols_sample_data_few_samples():
data = pd.DataFrame(
{
"datetime": [-2, 0, 1, 2, 4, 5, 6, 7, 10, 11],
"value": [
0.2, 2.1, -1.1, 0.01, 4.03, 2.37, -2.11, 1.96, -0.98, -1.32
]
}
)
trend = ols(data, 50)
assert list(trend.keys()) == ["trend", "uncertainty", "p_value"] assert list(trend.keys()) == ["trend", "uncertainty", "p_value"]
...@@ -14,5 +14,18 @@ def test_quant_reg_sample_data(): ...@@ -14,5 +14,18 @@ def test_quant_reg_sample_data():
] ]
} }
) )
trend = quant_reg(data, 0.5) trend = quant_reg(data, 0.5, 1000)
assert list(trend.keys()) == ["trend", "uncertainty", "p_value"]
def test_quant_reg_sample_data_few_samples():
data = pd.DataFrame(
{
"datetime": [-2, 0, 1, 2, 4, 5, 6, 7, 10, 11],
"value": [
0.2, 2.1, -1.1, 0.01, 4.03, 2.37, -2.11, 1.96, -0.98, -1.32
]
}
)
trend = quant_reg(data, 0.5, 50)
assert list(trend.keys()) == ["trend", "uncertainty", "p_value"] assert list(trend.keys()) == ["trend", "uncertainty", "p_value"]
...@@ -10,7 +10,7 @@ from toarstats.trends.quant_reg import quant_reg ...@@ -10,7 +10,7 @@ from toarstats.trends.quant_reg import quant_reg
from toarstats.trends.utils import calculate_anomalies from toarstats.trends.utils import calculate_anomalies
def calculate_trend(method, data, quantiles=None): def calculate_trend(method, data, quantiles=None, num_samples=1000):
"""Calculate the trend using the requested method. """Calculate the trend using the requested method.
This function is the public interface for the ``trends`` subpackage. This function is the public interface for the ``trends`` subpackage.
...@@ -28,6 +28,8 @@ def calculate_trend(method, data, quantiles=None): ...@@ -28,6 +28,8 @@ def calculate_trend(method, data, quantiles=None):
:param quantiles: a single quantile or a list of quantiles to :param quantiles: a single quantile or a list of quantiles to
calculate, these must be between 0 and 1; only calculate, these must be between 0 and 1; only
needed when ``method="quant"`` needed when ``method="quant"``
:param num_samples: number of sampled trends in moving block
bootstrap
:raises TypeError: raised if :raises TypeError: raised if
...@@ -46,6 +48,7 @@ def calculate_trend(method, data, quantiles=None): ...@@ -46,6 +48,7 @@ def calculate_trend(method, data, quantiles=None):
- the index and values have different lengths - the index and values have different lengths
- any ``quantiles`` are not strictly within 0 - any ``quantiles`` are not strictly within 0
and 1 with ``method="quant"`` and 1 with ``method="quant"``
- ``num_samples`` is not a positive integer
:return: The result of the fit or a dict of fit results if :return: The result of the fit or a dict of fit results if
``method="quant"`` ``method="quant"``
...@@ -62,12 +65,14 @@ def calculate_trend(method, data, quantiles=None): ...@@ -62,12 +65,14 @@ def calculate_trend(method, data, quantiles=None):
) )
if not all(0 < quantile < 1 for quantile in quantile_list): if not all(0 < quantile < 1 for quantile in quantile_list):
raise ValueError("The quantiles must be strictly between 0 and 1.") raise ValueError("The quantiles must be strictly between 0 and 1.")
if not isinstance(num_samples, int) or num_samples < 1:
raise ValueError("The number of samples must be a positive integer.")
anomalies_series = calculate_anomalies(data_in) anomalies_series = calculate_anomalies(data_in)
if method == "quant": if method == "quant":
fit = { fit = {
quantile: quant_reg(anomalies_series, quantile) quantile: quant_reg(anomalies_series, quantile, num_samples)
for quantile in quantile_list for quantile in quantile_list
} }
else: else:
fit = ols(anomalies_series) fit = ols(anomalies_series, num_samples)
return fit return fit
...@@ -11,17 +11,19 @@ import statsmodels.formula.api as smf ...@@ -11,17 +11,19 @@ import statsmodels.formula.api as smf
from toarstats.trends.utils import moving_block_bootstrap from toarstats.trends.utils import moving_block_bootstrap
def ols(data): def ols(data, num_samples):
"""Calculate the OLS linear regression. """Calculate the OLS linear regression.
:param data: data containing a list of date time values and :param data: data containing a list of date time values and
associated parameter values on which to calculate the associated parameter values on which to calculate the
trend trend
:param num_samples: number of sampled trends in moving block
bootstrap
:return: The trend with its uncertainty and p value :return: The trend with its uncertainty and p value
""" """
fit = smf.ols("value~datetime", data).fit(method="qr").params fit = smf.ols("value~datetime", data).fit(method="qr").params
mbb = moving_block_bootstrap("OLS", data) mbb = moving_block_bootstrap("OLS", data, num_samples=num_samples)
fit_se = np.nanstd(mbb, axis=0) fit_se = np.nanstd(mbb, axis=0)
fit_pv = 2*scipy.stats.t.sf(x=abs(fit/fit_se), df=len(data)-2) fit_pv = 2*scipy.stats.t.sf(x=abs(fit/fit_se), df=len(data)-2)
return {"trend": fit, "uncertainty": fit_se, "p_value": fit_pv} return {"trend": fit, "uncertainty": fit_se, "p_value": fit_pv}
...@@ -11,18 +11,20 @@ import statsmodels.formula.api as smf ...@@ -11,18 +11,20 @@ import statsmodels.formula.api as smf
from toarstats.trends.utils import moving_block_bootstrap from toarstats.trends.utils import moving_block_bootstrap
def quant_reg(data, quantile): def quant_reg(data, quantile, num_samples):
"""Calculate the quantile regression. """Calculate the quantile regression.
:param data: data containing a list of date time values and :param data: data containing a list of date time values and
associated parameter values on which to calculate the associated parameter values on which to calculate the
trend trend
:param quantile: a single quantile, must be between 0 and 1 :param quantile: a single quantile, must be between 0 and 1
:param num_samples: number of sampled trends in moving block
bootstrap
:return: The trend with its uncertainty and p value :return: The trend with its uncertainty and p value
""" """
fit = smf.quantreg("value~datetime", data).fit(q=quantile).params fit = smf.quantreg("value~datetime", data).fit(q=quantile).params
mbb = moving_block_bootstrap("quant", data, quantile) mbb = moving_block_bootstrap("quant", data, quantile, num_samples)
fit_se = np.nanstd(mbb, axis=0) fit_se = np.nanstd(mbb, axis=0)
fit_pv = 2*scipy.stats.t.sf(x=abs(fit/fit_se), df=len(data)-2) fit_pv = 2*scipy.stats.t.sf(x=abs(fit/fit_se), df=len(data)-2)
return {"trend": fit, "uncertainty": fit_se, "p_value": fit_pv} return {"trend": fit, "uncertainty": fit_se, "p_value": fit_pv}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment