diff --git a/tests/test_metrics/test_metrics.py b/tests/test_metrics/test_metrics.py index d7b68939d0b0b28d4aa6687306dfb9d8e6408612..e5020b5f997d16b4573a59b139cadb18338da6c5 100644 --- a/tests/test_metrics/test_metrics.py +++ b/tests/test_metrics/test_metrics.py @@ -1,20 +1,3 @@ -"""Tests for the metrics subpackage as a whole. - -This module contains tests to check if everything from older package -versions is implemented and if the results are still the same. - -This module contains the following functions: -create_sample_data - create sample data -get_all_statistics - get all implemented statistics -get_all_samplings - get all implemented samplings -sample_data - get sample data frame -sample_metadata - get sample metadata - -This module contains the following tests: -test_all_statistics_and_samplings_from_old_versions_implemented -test_results_match_reference_results -""" - import ast from collections import namedtuple from configparser import ConfigParser @@ -29,10 +12,6 @@ from toarstats.metrics.stats_utils import STATS_LIST def create_sample_data(sample_data_dir): - """Create sample data. - - :param sample_data_dir: path to the sample data directory - """ sample_data_dir.mkdir(exist_ok=True) datetime_index = pd.date_range(start="2011-04-17 09:00", periods=100000, freq="H") @@ -53,18 +32,10 @@ def create_sample_data(sample_data_dir): def get_all_statistics(): - """Get all implemented statistics. - - :return: A set of all implemented statistics - """ return set(STATS_LIST) def get_all_samplings(): - """Get all implemented samplings. - - :return: A set of all implemented samplings - """ samplings = set() for file in Path(Path(__file__).resolve().parents[2], "toarstats/metrics").glob("*.py"): @@ -81,10 +52,6 @@ def get_all_samplings(): @pytest.fixture def sample_data(): - """Get sample data frame. - - :return: A data frame with sample data - """ sample_data_file = Path( Path(__file__).resolve().parent, "sample_data/sample_data.csv" ) @@ -98,10 +65,6 @@ def sample_data(): @pytest.fixture def sample_metadata(): - """Get sample metadata. - - :return: A ``namedtuple`` with the sample metadata information - """ parser = ConfigParser() with open(Path(Path(__file__).resolve().parent, "sample_data/sample_metadata.cfg"), @@ -115,7 +78,6 @@ def sample_metadata(): def test_all_statistics_and_samplings_from_old_versions_implemented(): - """Test if all old statistics and samplings are implemented.""" old_statistics = set() old_samplings = set() for file in Path(Path(__file__).resolve().parent, @@ -131,7 +93,6 @@ def test_all_statistics_and_samplings_from_old_versions_implemented(): @pytest.mark.parametrize("statistic", sorted(get_all_statistics())) def test_results_match_reference_results(statistic, sampling, sample_data, sample_metadata): - """Test if the results match old package versions.""" if statistic == "drmdmax1h" and sampling == "monthly": with pytest.raises(ValueError, match="The drmdmax1h statistic cannot" " be evaluated with monthly" diff --git a/tests/test_metrics/test_stats_utils.py b/tests/test_metrics/test_stats_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0970043f582cd472970935849d2abe0b6180b8a0 --- /dev/null +++ b/tests/test_metrics/test_stats_utils.py @@ -0,0 +1,48 @@ +import numpy as np +import pandas as pd +import pytest + +from toarstats.metrics.stats_utils import ( + get_growing_season, get_seasons, harmonize_time, kth_highest, prepare_data +) + + +def test_get_growing_season(): + with pytest.raises( + KeyError, + match="wheat-cool_temperate_le30-SH not included in SEASON_DICT" + ): + growing_season = get_growing_season("wheat", "cool temperate", -15) + + +def test_get_seasons(): + seasons = get_seasons( + None, ["mean", "median"], None, ["DJF", "MAM"], None, True, False + ) + assert seasons == [["DJF", "MAM"], ["DJF", "MAM"]] + + +def test_harmonize_time(): + record_list = harmonize_time([{"ser": pd.Series()}], "monthly") + for record in record_list: + pd.testing.assert_index_equal( + record["ser"].index, pd.DatetimeIndex([]) + ) + + +def test_kth_highest(): + ser = kth_highest( + pd.Series(index=pd.DatetimeIndex([])), None, "annual", 26, 0.75 + ) + pd.testing.assert_series_equal( + ser, pd.Series(index=pd.DatetimeIndex([], ser.index.freq), dtype=float) + ) + + +def test_prepare_data(): + record = prepare_data( + pd.Series(index=pd.DatetimeIndex([])), + pd.Series(index=pd.DatetimeIndex(["2000"])), + "seasonal", ["DJF"], "mean" + ) + assert record[0]["ser"].index == pd.DatetimeIndex(["2000"]) diff --git a/toarstats/metrics/input_checks.py b/toarstats/metrics/input_checks.py index 3bd3d179e53e523c3a958887df4fbd18cf5276d8..37365b33701de513abaa1900b62ac074aa38ace9 100644 --- a/toarstats/metrics/input_checks.py +++ b/toarstats/metrics/input_checks.py @@ -228,7 +228,9 @@ def check_data(data_in, datetimes_in, values_in): check_index(index_out) check_values(values_out) if index_out.size != values_out.size: - raise ValueError("Datetime index and values must have the same length") + raise ValueError( + "Datetime index and values must have the same length" + ) # pragma: no cover if index_out.tz: index_out = index_out.tz_localize(None) return pd.Series(values_out, index_out) diff --git a/toarstats/metrics/ozone_metrics.py b/toarstats/metrics/ozone_metrics.py index 23ae283c5a19b6fc73228cf5ac49a72ef07bb107..6bdcec77449d6d1ad4a2de0c880b5c43f8f4c92e 100644 --- a/toarstats/metrics/ozone_metrics.py +++ b/toarstats/metrics/ozone_metrics.py @@ -278,7 +278,7 @@ def drmdmax1h(ser, ref, mtype, metadata, seasons, min_data_capture): tmp1["ser"] = tmp2 res = [tmp1] elif mtype == "monthly": - raise ValueError("Invalid mtype") + raise ValueError("Invalid mtype") # pragma: no cover else: tmpres = stat_processor_1("drmdmax1h", tmp2, tmp1ref, mtype, seasons, func=resample_with_date) diff --git a/toarstats/metrics/stats_utils.py b/toarstats/metrics/stats_utils.py index b80f60a7eff867e7ea71868b385666491de765c7..c6528cf646fc57fe40026f888760af1e5fbf9047 100644 --- a/toarstats/metrics/stats_utils.py +++ b/toarstats/metrics/stats_utils.py @@ -399,7 +399,7 @@ def resample(ser, ref, sampling, how, mincount=0, minfrac=None, else: try: ser_tmp = getattr(ser_resample, how)() - except AttributeError: + except AttributeError: # pragma: no cover ser_tmp = ser_resample.apply(how) fcov = (count / ref.resample(sampling).count() if minfrac is not None or how == "sum" else None)