Skip to content
Snippets Groups Projects
Commit 4c0c54db authored by lukas leufen's avatar lukas leufen
Browse files

more tests

parent edc5e931
No related branches found
No related tags found
2 merge requests!37include new development,!33Lukas issue036 feat local temp data storage
Pipeline #29114 passed
......@@ -45,7 +45,7 @@ class Distributor(keras.utils.Sequence):
for prev, curr in enumerate(range(1, num_mini_batches+1)):
x = x_total[prev*self.batch_size:curr*self.batch_size, ...]
y = [y_total[prev*self.batch_size:curr*self.batch_size, ...] for _ in range(mod_rank)]
if x is not None:
if x is not None: # pragma: no branch
yield (x, y)
if (k + 1) == len(self.generator) and curr == num_mini_batches and not fit_call:
return
......
......@@ -75,11 +75,11 @@ class DataGenerator(keras.utils.Sequence):
if self._iterator < self.__len__():
data = self.get_data_generator()
self._iterator += 1
if data.history is not None and data.label is not None:
if data.history is not None and data.label is not None: # pragma: no branch
return data.history.transpose("datetime", "window", "Stations", "variables"), \
data.label.squeeze("Stations").transpose("datetime", "window")
else:
self.__next__()
self.__next__() # pragma: no cover
else:
raise StopIteration
......
......@@ -5,6 +5,24 @@ import keras
import numpy as np
class TestHistoryAdvanced:
def test_init(self):
hist = HistoryAdvanced()
assert hist.validation_data is None
assert hist.model is None
assert isinstance(hist.epoch, list) and len(hist.epoch) == 0
assert isinstance(hist.history, dict) and len(hist.history.keys()) == 0
def test_on_train_begin(self):
hist = HistoryAdvanced()
hist.epoch = [1, 2, 3]
hist.history = {"mse": [10, 7, 4]}
hist.on_train_begin()
assert hist.epoch == [1, 2, 3]
assert hist.history == {"mse": [10, 7, 4]}
class TestLearningRateDecay:
def test_init(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment