From 786f0c8cf08e7d0102ed880b11dee01d96e304d6 Mon Sep 17 00:00:00 2001 From: leufen1 <l.leufen@fz-juelich.de> Date: Tue, 31 May 2022 15:54:15 +0200 Subject: [PATCH] update dummy data handler --- test/test_data_handler/test_iterator.py | 3 +++ test/test_run_modules/test_model_setup.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/test/test_data_handler/test_iterator.py b/test/test_data_handler/test_iterator.py index e47d725a..bb8ecb5d 100644 --- a/test/test_data_handler/test_iterator.py +++ b/test/test_data_handler/test_iterator.py @@ -106,6 +106,9 @@ class DummyData: Y2 = np.random.randint(21, 30, size=(self.number_of_samples, 5, 1)) # samples, window, variables return [Y1, Y2] + def get_data(self, upsampling=False, as_numpy=True): + return self.get_X(upsampling, as_numpy), self.get_Y(upsampling, as_numpy) + class TestKerasIterator: diff --git a/test/test_run_modules/test_model_setup.py b/test/test_run_modules/test_model_setup.py index 962287e0..6e8d3ea9 100644 --- a/test/test_run_modules/test_model_setup.py +++ b/test/test_run_modules/test_model_setup.py @@ -150,3 +150,6 @@ class DummyData: Y1 = np.random.randint(0, 10, size=(self.number_of_samples, 5)) # samples, window Y2 = np.random.randint(21, 30, size=(self.number_of_samples, 3)) # samples, window return [Y1, Y2] + + def get_data(self, upsampling=False, as_numpy=True): + return self.get_X(upsampling, as_numpy), self.get_Y(upsampling, as_numpy) -- GitLab