Skip to content
Snippets Groups Projects
Commit 09f7e2d2 authored by lukas leufen's avatar lukas leufen
Browse files

corrected some failing tests

parent 5843d2f6
No related branches found
No related tags found
2 merge requests!37include new development,!29Lukas issue030 feat continue training
Pipeline #28882 passed
......@@ -42,16 +42,16 @@ class PreProcessing(RunEnvironment):
self.report_pre_processing()
def report_pre_processing(self):
logging.info(20 * '##')
logging.debug(20 * '##')
n_train = len(self.data_store.get('generator', 'general.train'))
n_val = len(self.data_store.get('generator', 'general.val'))
n_test = len(self.data_store.get('generator', 'general.test'))
n_total = n_train + n_val + n_test
logging.info(f"Number of all stations: {n_total}")
logging.info(f"Number of training stations: {n_train}")
logging.info(f"Number of val stations: {n_val}")
logging.info(f"Number of test stations: {n_test}")
logging.info(f"TEST SHAPE OF GENERATOR CALL: {self.data_store.get('generator', 'general.test')[0][0].shape}"
logging.debug(f"Number of all stations: {n_total}")
logging.debug(f"Number of training stations: {n_train}")
logging.debug(f"Number of val stations: {n_val}")
logging.debug(f"Number of test stations: {n_test}")
logging.debug(f"TEST SHAPE OF GENERATOR CALL: {self.data_store.get('generator', 'general.test')[0][0].shape}"
f"{self.data_store.get('generator', 'general.test')[0][1].shape}")
def split_train_val_test(self):
......
import pytest
from src.model_modules.keras_extensions import *
from src.helpers import l_p_loss
import keras
import numpy as np
......
......@@ -18,6 +18,9 @@ class TestModelSetup:
super(ModelSetup, obj).__init__()
obj.scope = "general.modeltest"
obj.model = None
obj.callbacks_name = "placeholder_%s_str.pickle"
obj.data_store.set("lr_decay", "dummy_str", "general.model")
obj.data_store.set("hist", "dummy_str", "general.model")
yield obj
RunEnvironment().__del__()
......
......@@ -47,14 +47,15 @@ class TestPreProcessing:
assert obj_with_exp_setup.data_store.search_name("generator") == []
assert obj_with_exp_setup._run() is None
assert obj_with_exp_setup.data_store.search_name("generator") == sorted(["general.train", "general.val",
"general.test"])
"general.train_val", "general.test"])
def test_split_train_val_test(self, obj_with_exp_setup):
assert obj_with_exp_setup.data_store.search_name("generator") == []
obj_with_exp_setup.split_train_val_test()
data_store = obj_with_exp_setup.data_store
assert data_store.search_scope("general.train") == sorted(["generator", "start", "end", "stations"])
assert data_store.search_name("generator") == sorted(["general.train", "general.val", "general.test"])
assert data_store.search_name("generator") == sorted(["general.train", "general.val", "general.test",
"general.train_val"])
def test_create_set_split_not_all_stations(self, caplog, obj_with_exp_setup):
caplog.set_level(logging.DEBUG)
......@@ -93,10 +94,11 @@ class TestPreProcessing:
def test_split_set_indices(self, obj_super_init):
dummy_list = list(range(0, 15))
train, val, test = obj_super_init.split_set_indices(len(dummy_list), 0.9)
train, val, test, train_val = obj_super_init.split_set_indices(len(dummy_list), 0.9)
assert dummy_list[train] == list(range(0, 10))
assert dummy_list[val] == list(range(10, 13))
assert dummy_list[test] == list(range(13, 15))
assert dummy_list[train_val] == list(range(0, 13))
def test_create_args_dict_default_scope(self, obj_super_init):
assert obj_super_init.data_store.create_args_dict(["NAME1", "NAME2"]) == {"NAME1": 1, "NAME2": 2}
......
......@@ -15,7 +15,7 @@ from src.run_modules.run_environment import RunEnvironment
from src.data_handling.data_distributor import Distributor
from src.data_handling.data_generator import DataGenerator
from src.helpers import PyTestRegex
from src.model_modules.keras_extensions import LearningRateDecay
from src.model_modules.keras_extensions import LearningRateDecay, HistoryAdvanced
def my_test_model(activation, window_history_size, channels, dropout_rate, add_minor_branch=False):
......@@ -49,6 +49,7 @@ class TestTraining:
obj.epochs = 2
obj.checkpoint = checkpoint
obj.lr_sc = LearningRateDecay()
obj.hist = HistoryAdvanced()
obj.experiment_name = "TestExperiment"
obj.data_store.set("generator", mock.MagicMock(return_value="mock_train_gen"), "general.train")
obj.data_store.set("generator", mock.MagicMock(return_value="mock_val_gen"), "general.val")
......@@ -133,6 +134,7 @@ class TestTraining:
obj.data_store.set("epochs", 2, "general.model")
obj.data_store.set("checkpoint", checkpoint, "general.model")
obj.data_store.set("lr_decay", LearningRateDecay(), "general.model")
obj.data_store.set("hist", HistoryAdvanced(), "general.model")
obj.data_store.set("experiment_name", "TestExperiment", "general")
obj.data_store.set("experiment_path", path, "general")
path_plot = os.path.join(path, "plots")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment