diff --git a/requirements.txt b/requirements.txt
index 9cd9ea44c3cd0068c985c52b07a7cfaa746d9b7c..e7c2f439966f6b085348af3078c814c7f0511024 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -43,6 +43,7 @@ pytest-cov==2.8.1
 pytest-html==2.0.1
 pytest-lazy-fixture==0.6.3
 pytest-metadata==1.8.0
+pytest-sugar
 python-dateutil==2.8.1
 pytz==2019.3
 PyYAML==5.3
diff --git a/requirements_gpu.txt b/requirements_gpu.txt
index 8e5a31e476e47b17d3f271199bbc151fc0dc0b50..9d1c2d62da0864d2626c7ada1aac4dcf6f633630 100644
--- a/requirements_gpu.txt
+++ b/requirements_gpu.txt
@@ -43,6 +43,7 @@ pytest-cov==2.8.1
 pytest-html==2.0.1
 pytest-lazy-fixture==0.6.3
 pytest-metadata==1.8.0
+pytest-sugar
 python-dateutil==2.8.1
 pytz==2019.3
 PyYAML==5.3
diff --git a/src/helpers.py b/src/helpers.py
index d108f3c30bbbe55965d3302d94571f740378503d..07e7e5dde3e20bcd016651cbd47d24970d38303d 100644
--- a/src/helpers.py
+++ b/src/helpers.py
@@ -166,6 +166,28 @@ class PyTestRegex:
         return self._regex.pattern
 
 
+class PyTestAllEqual:
+
+    def __init__(self, check_list):
+        self._list = check_list
+
+    def _check_all_equal(self):
+        equal = True
+        for b in self._list:
+            equal *= xr.testing.assert_equal(self._list[0], b) is None
+        return equal == 1
+
+    def is_true(self):
+        return self._check_all_equal()
+
+
+def xr_all_equal(check_list):
+    equal = True
+    for b in check_list:
+        equal *= xr.testing.assert_equal(check_list[0], b) is None
+    return equal == 1
+
+
 def dict_to_xarray(d: Dict, coordinate_name: str) -> xr.DataArray:
     """
     Convert a dictionary of 2D-xarrays to single 3D-xarray. The name of new coordinate axis follows <coordinate_name>.
diff --git a/test/test_data_handling/test_bootstraps.py b/test/test_data_handling/test_bootstraps.py
index 27494ee9918e1509787a2259cd07976627cb2b18..56b2a8bf9d4ce0d54c271e634b1bb8e171c80a6b 100644
--- a/test/test_data_handling/test_bootstraps.py
+++ b/test/test_data_handling/test_bootstraps.py
@@ -1,6 +1,7 @@
 
 from src.data_handling.bootstraps import BootStraps, BootStrapGenerator
 from src.data_handling.data_generator import DataGenerator
+from src.helpers import PyTestAllEqual, xr_all_equal
 
 import os
 import pytest
@@ -81,6 +82,8 @@ class TestBootstrapGenerator:
         dummy_content = xr.DataArray([1, 2, 3], dims="dummy")
         dummy_content.to_netcdf(os.path.join(path, "DEBW107_o3_temp_hist7_nboots20_shuffled.nc"))
         dummy_content.to_netcdf(os.path.join(path, "DEBW013_o3_temp_hist7_nboots20_shuffled.nc"))
+        dummy_content = dummy_content.expand_dims({"type": ["CNN"]})
+        dummy_content.to_netcdf(os.path.join(path, "forecasts_norm_DEBW107_test.nc"))
         return BootStrapGenerator(orig_generator, 20, path)
 
     def test_init(self, orig_generator):
@@ -114,17 +117,15 @@ class TestBootstrapGenerator:
             res.append(label)
         assert len(res) == boot_gen.number_of_boots
         assert xr.testing.assert_equal(res[0], res[-1]) is None
-
-        def all_equal(check_list):
-            equal = True
-            for b in check_list:
-                equal *= xr.testing.assert_equal(check_list[0], b) is None
-            return equal
-        assert all_equal(res)
-
+        assert PyTestAllEqual(res).is_true()
 
     def test_get_orig_prediction(self, boot_gen):
-        pass
+        path = boot_gen.orig_generator.data_path
+        res = []
+        for pred in boot_gen.get_orig_prediction(path, "forecasts_norm_DEBW107_test.nc"):
+            res.append(pred)
+        assert len(res) == boot_gen.number_of_boots+1
+        assert PyTestAllEqual(res).is_true()
 
     def test_load_shuffled_data(self, boot_gen):
         shuffled_data = boot_gen.load_shuffled_data("DEBW107", ["o3", "temp"])