Skip to content
Snippets Groups Projects
Commit 52c4b149 authored by leufen1's avatar leufen1
Browse files

tests for default workflow implemented with little refac, /close #245

parent f4fa78c1
No related branches found
No related tags found
3 merge requests!226Develop,!225Resolve "release v1.2.0",!222Resolve "new tests for default workflow"
Pipeline #55655 passed
...@@ -54,41 +54,10 @@ class DefaultWorkflow(Workflow): ...@@ -54,41 +54,10 @@ class DefaultWorkflow(Workflow):
self.add(PostProcessing) self.add(PostProcessing)
class DefaultWorkflowHPC(Workflow): class DefaultWorkflowHPC(DefaultWorkflow):
"""A default workflow for Jülich HPC systems executing ExperimentSetup, PreProcessing, PartitionCheck, ModelSetup, """A default workflow for Jülich HPC systems executing ExperimentSetup, PreProcessing, PartitionCheck, ModelSetup,
Training and PostProcessing in exact the mentioned ordering.""" Training and PostProcessing in exact the mentioned ordering."""
def __init__(self, stations=None,
train_model=None, create_new_model=None,
window_history_size=None,
experiment_date="testrun",
variables=None, statistics_per_var=None,
start=None, end=None,
target_var=None, target_dim=None,
window_lead_time=None,
dimensions=None,
interpolation_method=None, time_dim=None, limit_nan_fill=None,
train_start=None, train_end=None, val_start=None, val_end=None, test_start=None, test_end=None,
use_all_stations_on_all_data_sets=None, fraction_of_train=None,
experiment_path=None, plot_path=None, forecast_path=None, bootstrap_path=None,
overwrite_local_data=None,
sampling=None,
permute_data_on_training=None, extreme_values=None, extremes_on_right_tail_only=None,
transformation=None,
train_min_length=None, val_min_length=None, test_min_length=None,
evaluate_bootstraps=None, number_of_bootstraps=None, create_new_bootstraps=None,
plot_list=None,
model=None,
batch_size=None,
epochs=None,
data_handler=None, **kwargs):
super().__init__()
# extract all given kwargs arguments
params = remove_items(inspect.getfullargspec(self.__init__).args, "self")
kwargs_default = {k: v for k, v in locals().items() if k in params and v is not None}
self._setup(**kwargs_default, **kwargs)
def _setup(self, **kwargs): def _setup(self, **kwargs):
"""Set up default workflow.""" """Set up default workflow."""
self.add(ExperimentSetup, **kwargs) self.add(ExperimentSetup, **kwargs)
......
from mlair.workflows.default_workflow import DefaultWorkflow, DefaultWorkflowHPC
from mlair.run_modules.experiment_setup import ExperimentSetup
from mlair.run_modules.pre_processing import PreProcessing
from mlair.run_modules.model_setup import ModelSetup
from mlair.run_modules.partition_check import PartitionCheck
from mlair.run_modules.training import Training
from mlair.run_modules.post_processing import PostProcessing
class TestDefaultWorkflow:
def test_init_no_args(self):
flow = DefaultWorkflow()
assert flow._registry[0].__name__ == ExperimentSetup.__name__
assert len(flow._registry_kwargs[0].keys()) == 1
def test_init_with_args(self):
flow = DefaultWorkflow(stations="test", start="2020", model=None)
assert flow._registry[0].__name__ == ExperimentSetup.__name__
assert len(flow._registry_kwargs[0].keys()) == 3
def test_init_with_kwargs(self):
flow = DefaultWorkflow(stations="test", real_kwarg=4)
assert flow._registry[0].__name__ == ExperimentSetup.__name__
assert len(flow._registry_kwargs[0].keys()) == 3
assert list(flow._registry_kwargs[0].keys()) == ["experiment_date", "stations", "real_kwarg"]
def test_setup(self):
flow = DefaultWorkflow()
assert len(flow._registry) == 5
assert flow._registry[0].__name__ == ExperimentSetup.__name__
assert flow._registry[1].__name__ == PreProcessing.__name__
assert flow._registry[2].__name__ == ModelSetup.__name__
assert flow._registry[3].__name__ == Training.__name__
assert flow._registry[4].__name__ == PostProcessing.__name__
class TestDefaultWorkflowHPC:
def test_setup(self):
flow = DefaultWorkflowHPC()
assert len(flow._registry) == 6
assert flow._registry[0].__name__ == ExperimentSetup.__name__
assert flow._registry[1].__name__ == PreProcessing.__name__
assert flow._registry[2].__name__ == PartitionCheck.__name__
assert flow._registry[3].__name__ == ModelSetup.__name__
assert flow._registry[4].__name__ == Training.__name__
assert flow._registry[5].__name__ == PostProcessing.__name__
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment