Skip to content
Snippets Groups Projects
Commit c60b26d2 authored by lukas leufen's avatar lukas leufen
Browse files

Merge branch 'lukas_issue245_test_new-tests-for-default-workflow' into 'develop'

Resolve "new tests for default workflow"

See merge request toar/mlair!222
parents f4fa78c1 52c4b149
Branches
Tags
3 merge requests!226Develop,!225Resolve "release v1.2.0",!222Resolve "new tests for default workflow"
Pipeline #55662 passed with warnings
...@@ -54,41 +54,10 @@ class DefaultWorkflow(Workflow): ...@@ -54,41 +54,10 @@ class DefaultWorkflow(Workflow):
self.add(PostProcessing) self.add(PostProcessing)
class DefaultWorkflowHPC(Workflow): class DefaultWorkflowHPC(DefaultWorkflow):
"""A default workflow for Jülich HPC systems executing ExperimentSetup, PreProcessing, PartitionCheck, ModelSetup, """A default workflow for Jülich HPC systems executing ExperimentSetup, PreProcessing, PartitionCheck, ModelSetup,
Training and PostProcessing in exact the mentioned ordering.""" Training and PostProcessing in exact the mentioned ordering."""
def __init__(self, stations=None,
train_model=None, create_new_model=None,
window_history_size=None,
experiment_date="testrun",
variables=None, statistics_per_var=None,
start=None, end=None,
target_var=None, target_dim=None,
window_lead_time=None,
dimensions=None,
interpolation_method=None, time_dim=None, limit_nan_fill=None,
train_start=None, train_end=None, val_start=None, val_end=None, test_start=None, test_end=None,
use_all_stations_on_all_data_sets=None, fraction_of_train=None,
experiment_path=None, plot_path=None, forecast_path=None, bootstrap_path=None,
overwrite_local_data=None,
sampling=None,
permute_data_on_training=None, extreme_values=None, extremes_on_right_tail_only=None,
transformation=None,
train_min_length=None, val_min_length=None, test_min_length=None,
evaluate_bootstraps=None, number_of_bootstraps=None, create_new_bootstraps=None,
plot_list=None,
model=None,
batch_size=None,
epochs=None,
data_handler=None, **kwargs):
super().__init__()
# extract all given kwargs arguments
params = remove_items(inspect.getfullargspec(self.__init__).args, "self")
kwargs_default = {k: v for k, v in locals().items() if k in params and v is not None}
self._setup(**kwargs_default, **kwargs)
def _setup(self, **kwargs): def _setup(self, **kwargs):
"""Set up default workflow.""" """Set up default workflow."""
self.add(ExperimentSetup, **kwargs) self.add(ExperimentSetup, **kwargs)
......
from mlair.workflows.default_workflow import DefaultWorkflow, DefaultWorkflowHPC
from mlair.run_modules.experiment_setup import ExperimentSetup
from mlair.run_modules.pre_processing import PreProcessing
from mlair.run_modules.model_setup import ModelSetup
from mlair.run_modules.partition_check import PartitionCheck
from mlair.run_modules.training import Training
from mlair.run_modules.post_processing import PostProcessing
class TestDefaultWorkflow:
def test_init_no_args(self):
flow = DefaultWorkflow()
assert flow._registry[0].__name__ == ExperimentSetup.__name__
assert len(flow._registry_kwargs[0].keys()) == 1
def test_init_with_args(self):
flow = DefaultWorkflow(stations="test", start="2020", model=None)
assert flow._registry[0].__name__ == ExperimentSetup.__name__
assert len(flow._registry_kwargs[0].keys()) == 3
def test_init_with_kwargs(self):
flow = DefaultWorkflow(stations="test", real_kwarg=4)
assert flow._registry[0].__name__ == ExperimentSetup.__name__
assert len(flow._registry_kwargs[0].keys()) == 3
assert list(flow._registry_kwargs[0].keys()) == ["experiment_date", "stations", "real_kwarg"]
def test_setup(self):
flow = DefaultWorkflow()
assert len(flow._registry) == 5
assert flow._registry[0].__name__ == ExperimentSetup.__name__
assert flow._registry[1].__name__ == PreProcessing.__name__
assert flow._registry[2].__name__ == ModelSetup.__name__
assert flow._registry[3].__name__ == Training.__name__
assert flow._registry[4].__name__ == PostProcessing.__name__
class TestDefaultWorkflowHPC:
def test_setup(self):
flow = DefaultWorkflowHPC()
assert len(flow._registry) == 6
assert flow._registry[0].__name__ == ExperimentSetup.__name__
assert flow._registry[1].__name__ == PreProcessing.__name__
assert flow._registry[2].__name__ == PartitionCheck.__name__
assert flow._registry[3].__name__ == ModelSetup.__name__
assert flow._registry[4].__name__ == Training.__name__
assert flow._registry[5].__name__ == PostProcessing.__name__
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment