diff --git a/src/data_handling/data_distributor.py b/src/data_handling/data_distributor.py index 74df5f6ac1c998e644fa7d89a688fc12dee21265..c6f38a6f0e70518956bcbbd51a6fdfc1a1e7849f 100644 --- a/src/data_handling/data_distributor.py +++ b/src/data_handling/data_distributor.py @@ -45,7 +45,7 @@ class Distributor(keras.utils.Sequence): for prev, curr in enumerate(range(1, num_mini_batches+1)): x = x_total[prev*self.batch_size:curr*self.batch_size, ...] y = [y_total[prev*self.batch_size:curr*self.batch_size, ...] for _ in range(mod_rank)] - if x is not None: + if x is not None: # pragma: no branch yield (x, y) if (k + 1) == len(self.generator) and curr == num_mini_batches and not fit_call: return diff --git a/src/data_handling/data_generator.py b/src/data_handling/data_generator.py index 26b12d5955f2dd44661fe1da4450cb113c37b1b7..732a7efdf8f360b49823dfb6ca5ca3239cc774af 100644 --- a/src/data_handling/data_generator.py +++ b/src/data_handling/data_generator.py @@ -75,11 +75,11 @@ class DataGenerator(keras.utils.Sequence): if self._iterator < self.__len__(): data = self.get_data_generator() self._iterator += 1 - if data.history is not None and data.label is not None: + if data.history is not None and data.label is not None: # pragma: no branch return data.history.transpose("datetime", "window", "Stations", "variables"), \ data.label.squeeze("Stations").transpose("datetime", "window") else: - self.__next__() + self.__next__() # pragma: no cover else: raise StopIteration diff --git a/test/test_model_modules/test_keras_extensions.py b/test/test_model_modules/test_keras_extensions.py index c50e5e425779575eb2b492213a0b39b2b7c3376e..7c32844d54e88f61690e65885b8997e98a698ff5 100644 --- a/test/test_model_modules/test_keras_extensions.py +++ b/test/test_model_modules/test_keras_extensions.py @@ -5,6 +5,24 @@ import keras import numpy as np +class TestHistoryAdvanced: + + def test_init(self): + hist = HistoryAdvanced() + assert hist.validation_data is None + assert hist.model is None + assert isinstance(hist.epoch, list) and len(hist.epoch) == 0 + assert isinstance(hist.history, dict) and len(hist.history.keys()) == 0 + + def test_on_train_begin(self): + hist = HistoryAdvanced() + hist.epoch = [1, 2, 3] + hist.history = {"mse": [10, 7, 4]} + hist.on_train_begin() + assert hist.epoch == [1, 2, 3] + assert hist.history == {"mse": [10, 7, 4]} + + class TestLearningRateDecay: def test_init(self):