Skip to content
Snippets Groups Projects
Commit 3b9ee13f authored by Niklas Selke's avatar Niklas Selke
Browse files

Merge branch 'niklas_issue009_feat_calculate-quantile-regression' into 'develop'

Added 'num_samples' parameter to the trends interface.

See merge request !8
parents 514e9f7c 4aa8cba0
Branches
Tags
2 merge requests!9Develop,!8Added 'num_samples' parameter to the trends interface.
......@@ -18,6 +18,14 @@ def test_calculate_trend_wrong_quantile(mean_data):
trend = calculate_trend("quant", mean_data, 1.2)
def test_calculate_trend_wrong_num_samples(mean_data):
with pytest.raises(
ValueError,
match="The number of samples must be a positive integer."
):
trend = calculate_trend("quant", mean_data, 0.5, -1)
@pytest.mark.filterwarnings("ignore:Maximum number of iterations")
@pytest.mark.slow
def test_calculate_trend_one_quantile(mean_data):
......@@ -25,6 +33,12 @@ def test_calculate_trend_one_quantile(mean_data):
assert list(trend.keys()) == [0.5]
@pytest.mark.filterwarnings("ignore:Maximum number of iterations")
def test_calculate_trend_one_quantile_few_samples(mean_data):
trend = calculate_trend("quant", mean_data, 0.5, 50)
assert list(trend.keys()) == [0.5]
@pytest.mark.filterwarnings("ignore:Maximum number of iterations")
@pytest.mark.slow
def test_calculate_trend_multiple_quantiles(mean_data):
......@@ -32,7 +46,18 @@ def test_calculate_trend_multiple_quantiles(mean_data):
assert list(trends.keys()) == [0.25, 0.5, 0.75]
@pytest.mark.filterwarnings("ignore:Maximum number of iterations")
def test_calculate_trend_multiple_quantiles_few_samples(mean_data):
trends = calculate_trend("quant", mean_data, [0.25, 0.5, 0.75], 50)
assert list(trends.keys()) == [0.25, 0.5, 0.75]
@pytest.mark.slow
def test_calculate_trend_ols(mean_data):
trend = calculate_trend("OLS", mean_data)
assert list(trend.keys()) == ["trend", "uncertainty", "p_value"]
def test_calculate_trend_ols_few_samples(mean_data):
trend = calculate_trend("OLS", mean_data, num_samples=50)
assert list(trend.keys()) == ["trend", "uncertainty", "p_value"]
......@@ -14,5 +14,18 @@ def test_ols_sample_data():
]
}
)
trend = ols(data)
trend = ols(data, 1000)
assert list(trend.keys()) == ["trend", "uncertainty", "p_value"]
def test_ols_sample_data_few_samples():
data = pd.DataFrame(
{
"datetime": [-2, 0, 1, 2, 4, 5, 6, 7, 10, 11],
"value": [
0.2, 2.1, -1.1, 0.01, 4.03, 2.37, -2.11, 1.96, -0.98, -1.32
]
}
)
trend = ols(data, 50)
assert list(trend.keys()) == ["trend", "uncertainty", "p_value"]
......@@ -14,5 +14,18 @@ def test_quant_reg_sample_data():
]
}
)
trend = quant_reg(data, 0.5)
trend = quant_reg(data, 0.5, 1000)
assert list(trend.keys()) == ["trend", "uncertainty", "p_value"]
def test_quant_reg_sample_data_few_samples():
data = pd.DataFrame(
{
"datetime": [-2, 0, 1, 2, 4, 5, 6, 7, 10, 11],
"value": [
0.2, 2.1, -1.1, 0.01, 4.03, 2.37, -2.11, 1.96, -0.98, -1.32
]
}
)
trend = quant_reg(data, 0.5, 50)
assert list(trend.keys()) == ["trend", "uncertainty", "p_value"]
......@@ -10,7 +10,7 @@ from toarstats.trends.quant_reg import quant_reg
from toarstats.trends.utils import calculate_anomalies
def calculate_trend(method, data, quantiles=None):
def calculate_trend(method, data, quantiles=None, num_samples=1000):
"""Calculate the trend using the requested method.
This function is the public interface for the ``trends`` subpackage.
......@@ -28,6 +28,8 @@ def calculate_trend(method, data, quantiles=None):
:param quantiles: a single quantile or a list of quantiles to
calculate, these must be between 0 and 1; only
needed when ``method="quant"``
:param num_samples: number of sampled trends in moving block
bootstrap
:raises TypeError: raised if
......@@ -46,6 +48,7 @@ def calculate_trend(method, data, quantiles=None):
- the index and values have different lengths
- any ``quantiles`` are not strictly within 0
and 1 with ``method="quantreg"``
- ``num_samples`` is not a positive integer
:return: The result of the fit or a dict of fit results if
``method="quant"``
......@@ -62,12 +65,14 @@ def calculate_trend(method, data, quantiles=None):
)
if not all(0 < quantile < 1 for quantile in quantile_list):
raise ValueError("The quantiles must be strictly between 0 and 1.")
if not isinstance(num_samples, int) or num_samples < 1:
raise ValueError("The number of samples must be a positive integer.")
anomalies_series = calculate_anomalies(data_in)
if method == "quant":
fit = {
quantile: quant_reg(anomalies_series, quantile)
quantile: quant_reg(anomalies_series, quantile, num_samples)
for quantile in quantile_list
}
else:
fit = ols(anomalies_series)
fit = ols(anomalies_series, num_samples)
return fit
......@@ -11,17 +11,19 @@ import statsmodels.formula.api as smf
from toarstats.trends.utils import moving_block_bootstrap
def ols(data):
def ols(data, num_samples):
"""Calculate the OLS linear regression.
:param data: data containing a list of date time values and
associated parameter values on which to calculate the
trend
:param num_samples: number of sampled trends in moving block
bootstrap
:return: The trend with its uncertainty and p value
"""
fit = smf.ols("value~datetime", data).fit(method="qr").params
mbb = moving_block_bootstrap("OLS", data)
mbb = moving_block_bootstrap("OLS", data, num_samples=num_samples)
fit_se = np.nanstd(mbb, axis=0)
fit_pv = 2*scipy.stats.t.sf(x=abs(fit/fit_se), df=len(data)-2)
return {"trend": fit, "uncertainty": fit_se, "p_value": fit_pv}
......@@ -11,18 +11,20 @@ import statsmodels.formula.api as smf
from toarstats.trends.utils import moving_block_bootstrap
def quant_reg(data, quantile):
def quant_reg(data, quantile, num_samples):
"""Calculate the quantile regression.
:param data: data containing a list of date time values and
associated parameter values on which to calculate the
trend
:param quantile: a single quantile, must be between 0 and 1
:param num_samples: number of sampled trends in moving block
bootstrap
:return: The trend with its uncertainty and p value
"""
fit = smf.quantreg("value~datetime", data).fit(q=quantile).params
mbb = moving_block_bootstrap("quant", data, quantile)
mbb = moving_block_bootstrap("quant", data, quantile, num_samples)
fit_se = np.nanstd(mbb, axis=0)
fit_pv = 2*scipy.stats.t.sf(x=abs(fit/fit_se), df=len(data)-2)
return {"trend": fit, "uncertainty": fit_se, "p_value": fit_pv}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment