Skip to content
Snippets Groups Projects
Commit a4d56c37 authored by lukas leufen's avatar lukas leufen
Browse files

add tests for data preparation

parent b0e27b1c
Branches
Tags
3 merge requests!90WIP: new release update,!89Resolve "release branch / CI on gpu",!77Resolve "Upsample "extremes" in standardised data space"
Pipeline #32205 passed
......@@ -443,15 +443,14 @@ class DataPrep(object):
"""
# check type if inputs
extreme_values = helpers.to_list(extreme_values)
extreme_values.sort()
for i in extreme_values:
if not isinstance(i, number.__args__):
raise TypeError(f"Elements of list extreme_values have to be {number.__args__}, but at least element "
f"{i} is type {type(i)}")
for extr_val in extreme_values:
for extr_val in sorted(extreme_values):
# check if some extreme values are already extracted
if not all([self.extremes_labels, self.extremes_history]):
if (self.extremes_labels is None) or (self.extremes_history is None):
# extract extremes based on occurance in labels
if extremes_on_right_tail_only:
extreme_label_idx = (self.label > extr_val).any(axis=0).values.reshape(-1,)
......
import datetime as dt
import os
from operator import itemgetter
from operator import itemgetter, lt, gt
import logging
import numpy as np
......@@ -403,3 +403,44 @@ class TestDataPrep:
data.make_labels("variables", "o3", "datetime", 2)
transposed = data.get_transposed_label()
assert transposed.coords.dims == ("datetime", "window")
def test_multiply_extremes(self, data):
data.transform("datetime")
data.make_history_window("variables", 3, "datetime")
data.make_labels("variables", "o3", "datetime", 2)
orig = data.label
data.multiply_extremes(1)
upsampled = data.extremes_labels
assert (upsampled > 1).sum() == (orig > 1).sum()
assert (upsampled < -1).sum() == (orig < -1).sum()
def test_multiply_extremes_from_list(self, data):
data.transform("datetime")
data.make_history_window("variables", 3, "datetime")
data.make_labels("variables", "o3", "datetime", 2)
orig = data.label
data.multiply_extremes([1, 1.5, 2, 3])
upsampled = data.extremes_labels
def f(d, op, n):
return op(d, n).any(dim="window").sum()
assert f(upsampled, gt, 1) == sum([f(orig, gt, 1), f(orig, gt, 1.5), f(orig, gt, 2) * 2, f(orig, gt, 3) * 4])
assert f(upsampled, lt, -1) == sum([f(orig, lt, -1), f(orig, lt, -1.5), f(orig, lt, -2) * 2, f(orig, lt, -3) * 4])
def test_multiply_extremes_wrong_extremes(self, data):
with pytest.raises(TypeError) as e:
data.multiply_extremes([1, "1.5", 2])
assert "Elements of list extreme_values have to be (<class 'float'>, <class 'int'>), but at least element 1.5" \
" is type <class 'str'>" in e.value.args[0]
def test_multiply_extremes_right_tail(self, data):
data.transform("datetime")
data.make_history_window("variables", 3, "datetime")
data.make_labels("variables", "o3", "datetime", 2)
orig = data.label
data.multiply_extremes([1, 2], extremes_on_right_tail_only=True)
upsampled = data.extremes_labels
def f(d, op, n):
return op(d, n).any(dim="window").sum()
assert f(upsampled, gt, 1) == sum([f(orig, gt, 1), f(orig, gt, 2)])
assert len(upsampled) == sum([f(orig, gt, 1), f(orig, gt, 2)])
assert f(upsampled, lt, -1) == 0
......@@ -54,7 +54,8 @@ class TestPreProcessing:
assert obj_with_exp_setup.data_store.search_name("generator") == []
obj_with_exp_setup.split_train_val_test()
data_store = obj_with_exp_setup.data_store
expected_params = ["generator", "start", "end", "stations", "permute_data", "min_length"]
expected_params = ["generator", "start", "end", "stations", "permute_data", "min_length", "extreme_values",
"extremes_on_right_tail_only", "upsampling"]
assert data_store.search_scope("general.train") == sorted(expected_params)
assert data_store.search_name("generator") == sorted(["general.train", "general.val", "general.test",
"general.train_val"])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment