from mlair.workflows.default_workflow import DefaultWorkflow, DefaultWorkflowHPC
from mlair.run_modules.experiment_setup import ExperimentSetup
from mlair.run_modules.pre_processing import PreProcessing
from mlair.run_modules.model_setup import ModelSetup
from mlair.run_modules.partition_check import PartitionCheck
from mlair.run_modules.training import Training
from mlair.run_modules.post_processing import PostProcessing


def _stage_names(flow):
    """Return the ``__name__`` of every stage in ``flow._registry``, in registration order."""
    return [stage.__name__ for stage in flow._registry]


class TestDefaultWorkflow:
    """Check how DefaultWorkflow forwards arguments and which stages it registers."""

    def test_init_no_args(self):
        flow = DefaultWorkflow()
        first_stage_kwargs = flow._registry_kwargs[0]
        assert flow._registry[0].__name__ == ExperimentSetup.__name__
        # With no user arguments, exactly one keyword still reaches the first stage.
        assert len(first_stage_kwargs.keys()) == 1

    def test_init_with_args(self):
        flow = DefaultWorkflow(stations="test", start="2020", model=None)
        first_stage_kwargs = flow._registry_kwargs[0]
        assert flow._registry[0].__name__ == ExperimentSetup.__name__
        # Three named arguments are passed in, and three keywords are expected to be forwarded.
        assert len(first_stage_kwargs.keys()) == 3

    def test_init_with_kwargs(self):
        flow = DefaultWorkflow(stations="test", real_kwarg=4)
        forwarded_keys = list(flow._registry_kwargs[0].keys())
        assert flow._registry[0].__name__ == ExperimentSetup.__name__
        assert len(forwarded_keys) == 3
        # Arbitrary extra keywords are forwarded too, after "experiment_date" and the named args.
        assert forwarded_keys == ["experiment_date", "stations", "real_kwarg"]

    def test_setup(self):
        flow = DefaultWorkflow()
        expected_order = [
            ExperimentSetup.__name__,
            PreProcessing.__name__,
            ModelSetup.__name__,
            Training.__name__,
            PostProcessing.__name__,
        ]
        # A single list comparison pins both the stage count and the stage order.
        assert _stage_names(flow) == expected_order


class TestDefaultWorkflowHPC:
    """Check which stages DefaultWorkflowHPC registers, and in what order."""

    def test_setup(self):
        flow = DefaultWorkflowHPC()
        expected_order = [
            ExperimentSetup.__name__,
            PreProcessing.__name__,
            PartitionCheck.__name__,
            ModelSetup.__name__,
            Training.__name__,
            PostProcessing.__name__,
        ]
        # The HPC variant adds a PartitionCheck stage between pre-processing and model setup.
        assert _stage_names(flow) == expected_order