From 6fc65e002a01f8dc843c529917ff9228eb828d50 Mon Sep 17 00:00:00 2001
From: Niklas Selke <n.selke@fz-juelich.de>
Date: Tue, 30 May 2023 16:18:06 +0200
Subject: [PATCH] Added additional tests for the 'metrics' subpackage.

---
 tests/test_metrics/test_metrics.py     | 39 ---------------------
 tests/test_metrics/test_stats_utils.py | 48 ++++++++++++++++++++++++++
 toarstats/metrics/input_checks.py      |  4 ++-
 toarstats/metrics/ozone_metrics.py     |  2 +-
 toarstats/metrics/stats_utils.py       |  2 +-
 5 files changed, 53 insertions(+), 42 deletions(-)
 create mode 100644 tests/test_metrics/test_stats_utils.py

diff --git a/tests/test_metrics/test_metrics.py b/tests/test_metrics/test_metrics.py
index d7b6893..e5020b5 100644
--- a/tests/test_metrics/test_metrics.py
+++ b/tests/test_metrics/test_metrics.py
@@ -1,20 +1,3 @@
-"""Tests for the metrics subpackage as a whole.
-
-This module contains tests to check if everything from older package
-versions is implemented and if the results are still the same.
-
-This module contains the following functions:
-create_sample_data - create sample data
-get_all_statistics - get all implemented statistics
-get_all_samplings - get all implemented samplings
-sample_data - get sample data frame
-sample_metadata - get sample metadata
-
-This module contains the following tests:
-test_all_statistics_and_samplings_from_old_versions_implemented
-test_results_match_reference_results
-"""
-
 import ast
 from collections import namedtuple
 from configparser import ConfigParser
@@ -29,10 +12,6 @@ from toarstats.metrics.stats_utils import STATS_LIST
 
 
 def create_sample_data(sample_data_dir):
-    """Create sample data.
-
-    :param sample_data_dir: path to the sample data directory
-    """
     sample_data_dir.mkdir(exist_ok=True)
     datetime_index = pd.date_range(start="2011-04-17 09:00", periods=100000,
                                    freq="H")
@@ -53,18 +32,10 @@ def create_sample_data(sample_data_dir):
 
 
 def get_all_statistics():
-    """Get all implemented statistics.
-
-    :return: A set of all implemented statistics
-    """
     return set(STATS_LIST)
 
 
 def get_all_samplings():
-    """Get all implemented samplings.
-
-    :return: A set of all implemented samplings
-    """
     samplings = set()
     for file in Path(Path(__file__).resolve().parents[2],
                      "toarstats/metrics").glob("*.py"):
@@ -81,10 +52,6 @@ def get_all_samplings():
 
 @pytest.fixture
 def sample_data():
-    """Get sample data frame.
-
-    :return: A data frame with sample data
-    """
     sample_data_file = Path(
         Path(__file__).resolve().parent, "sample_data/sample_data.csv"
     )
@@ -98,10 +65,6 @@ def sample_data():
 
 @pytest.fixture
 def sample_metadata():
-    """Get sample metadata.
-
-    :return: A ``namedtuple`` with the sample metadata information
-    """
     parser = ConfigParser()
     with open(Path(Path(__file__).resolve().parent,
                    "sample_data/sample_metadata.cfg"),
@@ -115,7 +78,6 @@ def sample_metadata():
 
 
 def test_all_statistics_and_samplings_from_old_versions_implemented():
-    """Test if all old statistics and samplings are implemented."""
     old_statistics = set()
     old_samplings = set()
     for file in Path(Path(__file__).resolve().parent,
@@ -131,7 +93,6 @@ def test_all_statistics_and_samplings_from_old_versions_implemented():
 @pytest.mark.parametrize("statistic", sorted(get_all_statistics()))
 def test_results_match_reference_results(statistic, sampling, sample_data,
                                          sample_metadata):
-    """Test if the results match old package versions."""
     if statistic == "drmdmax1h" and sampling == "monthly":
         with pytest.raises(ValueError, match="The drmdmax1h statistic cannot"
                                              " be evaluated with monthly"
diff --git a/tests/test_metrics/test_stats_utils.py b/tests/test_metrics/test_stats_utils.py
new file mode 100644
index 0000000..0970043
--- /dev/null
+++ b/tests/test_metrics/test_stats_utils.py
@@ -0,0 +1,48 @@
+import numpy as np
+import pandas as pd
+import pytest
+
+from toarstats.metrics.stats_utils import (
+    get_growing_season, get_seasons, harmonize_time, kth_highest, prepare_data
+)
+
+
+def test_get_growing_season():
+    with pytest.raises(
+            KeyError,
+            match="wheat-cool_temperate_le30-SH not included in SEASON_DICT"
+    ):
+        growing_season = get_growing_season("wheat", "cool temperate", -15)
+
+
+def test_get_seasons():
+    seasons = get_seasons(
+        None, ["mean", "median"], None, ["DJF", "MAM"], None, True, False
+    )
+    assert seasons == [["DJF", "MAM"], ["DJF", "MAM"]]
+
+
+def test_harmonize_time():
+    record_list = harmonize_time([{"ser": pd.Series()}], "monthly")
+    for record in record_list:
+        pd.testing.assert_index_equal(
+            record["ser"].index, pd.DatetimeIndex([])
+        )
+
+
+def test_kth_highest():
+    ser = kth_highest(
+        pd.Series(index=pd.DatetimeIndex([])), None, "annual", 26, 0.75
+    )
+    pd.testing.assert_series_equal(
+        ser, pd.Series(index=pd.DatetimeIndex([], ser.index.freq), dtype=float)
+    )
+
+
+def test_prepare_data():
+    record = prepare_data(
+        pd.Series(index=pd.DatetimeIndex([])),
+        pd.Series(index=pd.DatetimeIndex(["2000"])),
+        "seasonal", ["DJF"], "mean"
+    )
+    assert record[0]["ser"].index == pd.DatetimeIndex(["2000"])
diff --git a/toarstats/metrics/input_checks.py b/toarstats/metrics/input_checks.py
index 3bd3d17..37365b3 100644
--- a/toarstats/metrics/input_checks.py
+++ b/toarstats/metrics/input_checks.py
@@ -228,7 +228,9 @@ def check_data(data_in, datetimes_in, values_in):
     check_index(index_out)
     check_values(values_out)
     if index_out.size != values_out.size:
-        raise ValueError("Datetime index and values must have the same length")
+        raise ValueError(
+            "Datetime index and values must have the same length"
+        )  # pragma: no cover
     if index_out.tz:
         index_out = index_out.tz_localize(None)
     return pd.Series(values_out, index_out)
diff --git a/toarstats/metrics/ozone_metrics.py b/toarstats/metrics/ozone_metrics.py
index 23ae283..6bdcec7 100644
--- a/toarstats/metrics/ozone_metrics.py
+++ b/toarstats/metrics/ozone_metrics.py
@@ -278,7 +278,7 @@ def drmdmax1h(ser, ref, mtype, metadata, seasons, min_data_capture):
         tmp1["ser"] = tmp2
         res = [tmp1]
     elif mtype == "monthly":
-        raise ValueError("Invalid mtype")
+        raise ValueError("Invalid mtype")  # pragma: no cover
     else:
         tmpres = stat_processor_1("drmdmax1h", tmp2, tmp1ref, mtype, seasons,
                                   func=resample_with_date)
diff --git a/toarstats/metrics/stats_utils.py b/toarstats/metrics/stats_utils.py
index b80f60a..c6528cf 100644
--- a/toarstats/metrics/stats_utils.py
+++ b/toarstats/metrics/stats_utils.py
@@ -399,7 +399,7 @@ def resample(ser, ref, sampling, how, mincount=0, minfrac=None,
     else:
         try:
             ser_tmp = getattr(ser_resample, how)()
-        except AttributeError:
+        except AttributeError:  # pragma: no cover
             ser_tmp = ser_resample.apply(how)
     fcov = (count / ref.resample(sampling).count()
             if minfrac is not None or how == "sum" else None)
-- 
GitLab