diff --git a/mlair/workflows/default_workflow.py b/mlair/workflows/default_workflow.py index 4d113190fdc90ec852d7db2b33459b9162867a24..5894555a6af52299efcd8d88d76c0d3791a1599e 100644 --- a/mlair/workflows/default_workflow.py +++ b/mlair/workflows/default_workflow.py @@ -54,41 +54,10 @@ class DefaultWorkflow(Workflow): self.add(PostProcessing) -class DefaultWorkflowHPC(Workflow): +class DefaultWorkflowHPC(DefaultWorkflow): """A default workflow for Jülich HPC systems executing ExperimentSetup, PreProcessing, PartitionCheck, ModelSetup, Training and PostProcessing in exact the mentioned ordering.""" - def __init__(self, stations=None, - train_model=None, create_new_model=None, - window_history_size=None, - experiment_date="testrun", - variables=None, statistics_per_var=None, - start=None, end=None, - target_var=None, target_dim=None, - window_lead_time=None, - dimensions=None, - interpolation_method=None, time_dim=None, limit_nan_fill=None, - train_start=None, train_end=None, val_start=None, val_end=None, test_start=None, test_end=None, - use_all_stations_on_all_data_sets=None, fraction_of_train=None, - experiment_path=None, plot_path=None, forecast_path=None, bootstrap_path=None, - overwrite_local_data=None, - sampling=None, - permute_data_on_training=None, extreme_values=None, extremes_on_right_tail_only=None, - transformation=None, - train_min_length=None, val_min_length=None, test_min_length=None, - evaluate_bootstraps=None, number_of_bootstraps=None, create_new_bootstraps=None, - plot_list=None, - model=None, - batch_size=None, - epochs=None, - data_handler=None, **kwargs): - super().__init__() - - # extract all given kwargs arguments - params = remove_items(inspect.getfullargspec(self.__init__).args, "self") - kwargs_default = {k: v for k, v in locals().items() if k in params and v is not None} - self._setup(**kwargs_default, **kwargs) - def _setup(self, **kwargs): """Set up default workflow.""" self.add(ExperimentSetup, **kwargs) diff --git a/test/test_workflows/test_default_workflow.py b/test/test_workflows/test_default_workflow.py new file mode 100644 index 0000000000000000000000000000000000000000..c7c198a4821f779329b9f5f19b04e757d8ebc7da --- /dev/null +++ b/test/test_workflows/test_default_workflow.py @@ -0,0 +1,48 @@ +from mlair.workflows.default_workflow import DefaultWorkflow, DefaultWorkflowHPC +from mlair.run_modules.experiment_setup import ExperimentSetup +from mlair.run_modules.pre_processing import PreProcessing +from mlair.run_modules.model_setup import ModelSetup +from mlair.run_modules.partition_check import PartitionCheck +from mlair.run_modules.training import Training +from mlair.run_modules.post_processing import PostProcessing + + +class TestDefaultWorkflow: + + def test_init_no_args(self): + flow = DefaultWorkflow() + assert flow._registry[0].__name__ == ExperimentSetup.__name__ + assert len(flow._registry_kwargs[0].keys()) == 1 + + def test_init_with_args(self): + flow = DefaultWorkflow(stations="test", start="2020", model=None) + assert flow._registry[0].__name__ == ExperimentSetup.__name__ + assert len(flow._registry_kwargs[0].keys()) == 3 + + def test_init_with_kwargs(self): + flow = DefaultWorkflow(stations="test", real_kwarg=4) + assert flow._registry[0].__name__ == ExperimentSetup.__name__ + assert len(flow._registry_kwargs[0].keys()) == 3 + assert list(flow._registry_kwargs[0].keys()) == ["experiment_date", "stations", "real_kwarg"] + + def test_setup(self): + flow = DefaultWorkflow() + assert len(flow._registry) == 5 + assert flow._registry[0].__name__ == ExperimentSetup.__name__ + assert flow._registry[1].__name__ == PreProcessing.__name__ + assert flow._registry[2].__name__ == ModelSetup.__name__ + assert flow._registry[3].__name__ == Training.__name__ + assert flow._registry[4].__name__ == PostProcessing.__name__ + + +class TestDefaultWorkflowHPC: + + def test_setup(self): + flow = DefaultWorkflowHPC() + assert len(flow._registry) == 6 + assert flow._registry[0].__name__ == ExperimentSetup.__name__ + assert flow._registry[1].__name__ == PreProcessing.__name__ + assert flow._registry[2].__name__ == PartitionCheck.__name__ + assert flow._registry[3].__name__ == ModelSetup.__name__ + assert flow._registry[4].__name__ == Training.__name__ + assert flow._registry[5].__name__ == PostProcessing.__name__