Skip to content
Snippets Groups Projects
Commit 09f7e2d2 authored by lukas leufen's avatar lukas leufen
Browse files

corrected some failing tests

parent 5843d2f6
Branches
Tags
2 merge requests!37include new development,!29Lukas issue030 feat continue training
Pipeline #28882 passed
......@@ -42,16 +42,16 @@ class PreProcessing(RunEnvironment):
self.report_pre_processing()
def report_pre_processing(self):
logging.info(20 * '##')
logging.debug(20 * '##')
n_train = len(self.data_store.get('generator', 'general.train'))
n_val = len(self.data_store.get('generator', 'general.val'))
n_test = len(self.data_store.get('generator', 'general.test'))
n_total = n_train + n_val + n_test
logging.info(f"Number of all stations: {n_total}")
logging.info(f"Number of training stations: {n_train}")
logging.info(f"Number of val stations: {n_val}")
logging.info(f"Number of test stations: {n_test}")
logging.info(f"TEST SHAPE OF GENERATOR CALL: {self.data_store.get('generator', 'general.test')[0][0].shape}"
logging.debug(f"Number of all stations: {n_total}")
logging.debug(f"Number of training stations: {n_train}")
logging.debug(f"Number of val stations: {n_val}")
logging.debug(f"Number of test stations: {n_test}")
logging.debug(f"TEST SHAPE OF GENERATOR CALL: {self.data_store.get('generator', 'general.test')[0][0].shape}"
f"{self.data_store.get('generator', 'general.test')[0][1].shape}")
def split_train_val_test(self):
......
import pytest
from src.model_modules.keras_extensions import *
from src.helpers import l_p_loss
import keras
import numpy as np
......
......@@ -18,6 +18,9 @@ class TestModelSetup:
super(ModelSetup, obj).__init__()
obj.scope = "general.modeltest"
obj.model = None
obj.callbacks_name = "placeholder_%s_str.pickle"
obj.data_store.set("lr_decay", "dummy_str", "general.model")
obj.data_store.set("hist", "dummy_str", "general.model")
yield obj
RunEnvironment().__del__()
......
......@@ -47,14 +47,15 @@ class TestPreProcessing:
assert obj_with_exp_setup.data_store.search_name("generator") == []
assert obj_with_exp_setup._run() is None
assert obj_with_exp_setup.data_store.search_name("generator") == sorted(["general.train", "general.val",
"general.test"])
"general.train_val", "general.test"])
def test_split_train_val_test(self, obj_with_exp_setup):
assert obj_with_exp_setup.data_store.search_name("generator") == []
obj_with_exp_setup.split_train_val_test()
data_store = obj_with_exp_setup.data_store
assert data_store.search_scope("general.train") == sorted(["generator", "start", "end", "stations"])
assert data_store.search_name("generator") == sorted(["general.train", "general.val", "general.test"])
assert data_store.search_name("generator") == sorted(["general.train", "general.val", "general.test",
"general.train_val"])
def test_create_set_split_not_all_stations(self, caplog, obj_with_exp_setup):
caplog.set_level(logging.DEBUG)
......@@ -93,10 +94,11 @@ class TestPreProcessing:
def test_split_set_indices(self, obj_super_init):
dummy_list = list(range(0, 15))
train, val, test = obj_super_init.split_set_indices(len(dummy_list), 0.9)
train, val, test, train_val = obj_super_init.split_set_indices(len(dummy_list), 0.9)
assert dummy_list[train] == list(range(0, 10))
assert dummy_list[val] == list(range(10, 13))
assert dummy_list[test] == list(range(13, 15))
assert dummy_list[train_val] == list(range(0, 13))
def test_create_args_dict_default_scope(self, obj_super_init):
assert obj_super_init.data_store.create_args_dict(["NAME1", "NAME2"]) == {"NAME1": 1, "NAME2": 2}
......
......@@ -15,7 +15,7 @@ from src.run_modules.run_environment import RunEnvironment
from src.data_handling.data_distributor import Distributor
from src.data_handling.data_generator import DataGenerator
from src.helpers import PyTestRegex
from src.model_modules.keras_extensions import LearningRateDecay
from src.model_modules.keras_extensions import LearningRateDecay, HistoryAdvanced
def my_test_model(activation, window_history_size, channels, dropout_rate, add_minor_branch=False):
......@@ -49,6 +49,7 @@ class TestTraining:
obj.epochs = 2
obj.checkpoint = checkpoint
obj.lr_sc = LearningRateDecay()
obj.hist = HistoryAdvanced()
obj.experiment_name = "TestExperiment"
obj.data_store.set("generator", mock.MagicMock(return_value="mock_train_gen"), "general.train")
obj.data_store.set("generator", mock.MagicMock(return_value="mock_val_gen"), "general.val")
......@@ -133,6 +134,7 @@ class TestTraining:
obj.data_store.set("epochs", 2, "general.model")
obj.data_store.set("checkpoint", checkpoint, "general.model")
obj.data_store.set("lr_decay", LearningRateDecay(), "general.model")
obj.data_store.set("hist", HistoryAdvanced(), "general.model")
obj.data_store.set("experiment_name", "TestExperiment", "general")
obj.data_store.set("experiment_path", path, "general")
path_plot = os.path.join(path, "plots")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment