diff --git a/src/data_handling/data_preparation.py b/src/data_handling/data_preparation.py index ff3006e341807eb822d12eca7f35a6823bc098a4..e3186778b94375ba1d39fa87ba7d2980c785581e 100644 --- a/src/data_handling/data_preparation.py +++ b/src/data_handling/data_preparation.py @@ -354,7 +354,6 @@ class DataPrep(object): intersect = reduce(np.intersect1d, (non_nan_history.coords[dim].values, non_nan_label.coords[dim].values, non_nan_observation.coords[dim].values)) min_length = self.kwargs.get("min_length", 0) - length = len(intersect) if len(intersect) < max(min_length, 1): self.history = None self.label = None diff --git a/test/test_data_handling/test_data_preparation.py b/test/test_data_handling/test_data_preparation.py index 91719f3dd16326ee6281c4db8ef3aa87e238d70f..85c4420609a466ff5f3eeb3d46cb6bb07fe9c30a 100644 --- a/test/test_data_handling/test_data_preparation.py +++ b/test/test_data_handling/test_data_preparation.py @@ -287,6 +287,14 @@ class TestDataPrep: assert remaining_len == data.label.datetime.shape assert remaining_len == data.observation.datetime.shape + def test_remove_nan_too_short(self, data): + data.kwargs["min_length"] = 4000 # actual length of series is 3940 + data.make_history_window('variables', -12, 'datetime') + data.make_labels('variables', 'o3', 'datetime', 3) + data.make_observation('variables', 'o3', 'datetime') + data.remove_nan('datetime') + assert not any([data.history, data.label, data.observation]) + def test_create_index_array(self, data): index_array = data.create_index_array('window', range(1, 4)) assert np.testing.assert_array_equal(index_array.data, [1, 2, 3]) is None