Skip to content
Snippets Groups Projects
Commit 4c0c54db authored by lukas leufen's avatar lukas leufen
Browse files

more tests

parent edc5e931
No related branches found
No related tags found
Loading
Checking pipeline status
...@@ -45,7 +45,7 @@ class Distributor(keras.utils.Sequence): ...@@ -45,7 +45,7 @@ class Distributor(keras.utils.Sequence):
for prev, curr in enumerate(range(1, num_mini_batches+1)): for prev, curr in enumerate(range(1, num_mini_batches+1)):
x = x_total[prev*self.batch_size:curr*self.batch_size, ...] x = x_total[prev*self.batch_size:curr*self.batch_size, ...]
y = [y_total[prev*self.batch_size:curr*self.batch_size, ...] for _ in range(mod_rank)] y = [y_total[prev*self.batch_size:curr*self.batch_size, ...] for _ in range(mod_rank)]
if x is not None: if x is not None: # pragma: no branch
yield (x, y) yield (x, y)
if (k + 1) == len(self.generator) and curr == num_mini_batches and not fit_call: if (k + 1) == len(self.generator) and curr == num_mini_batches and not fit_call:
return return
......
...@@ -75,11 +75,11 @@ class DataGenerator(keras.utils.Sequence): ...@@ -75,11 +75,11 @@ class DataGenerator(keras.utils.Sequence):
if self._iterator < self.__len__(): if self._iterator < self.__len__():
data = self.get_data_generator() data = self.get_data_generator()
self._iterator += 1 self._iterator += 1
if data.history is not None and data.label is not None: if data.history is not None and data.label is not None: # pragma: no branch
return data.history.transpose("datetime", "window", "Stations", "variables"), \ return data.history.transpose("datetime", "window", "Stations", "variables"), \
data.label.squeeze("Stations").transpose("datetime", "window") data.label.squeeze("Stations").transpose("datetime", "window")
else: else:
self.__next__() self.__next__() # pragma: no cover
else: else:
raise StopIteration raise StopIteration
......
...@@ -5,6 +5,24 @@ import keras ...@@ -5,6 +5,24 @@ import keras
import numpy as np import numpy as np
class TestHistoryAdvanced:
def test_init(self):
hist = HistoryAdvanced()
assert hist.validation_data is None
assert hist.model is None
assert isinstance(hist.epoch, list) and len(hist.epoch) == 0
assert isinstance(hist.history, dict) and len(hist.history.keys()) == 0
def test_on_train_begin(self):
hist = HistoryAdvanced()
hist.epoch = [1, 2, 3]
hist.history = {"mse": [10, 7, 4]}
hist.on_train_begin()
assert hist.epoch == [1, 2, 3]
assert hist.history == {"mse": [10, 7, 4]}
class TestLearningRateDecay: class TestLearningRateDecay:
def test_init(self): def test_init(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment