from mlair.workflows.default_workflow import DefaultWorkflow, DefaultWorkflowHPC
from mlair.run_modules.experiment_setup import ExperimentSetup
from mlair.run_modules.pre_processing import PreProcessing
from mlair.run_modules.model_setup import ModelSetup
from mlair.run_modules.partition_check import PartitionCheck
from mlair.run_modules.training import Training
from mlair.run_modules.post_processing import PostProcessing


class TestDefaultWorkflow:
    """Check the stage registry and first-stage kwargs built by ``DefaultWorkflow``."""

    def test_init_no_args(self):
        """Without arguments, exactly one keyword is registered for the first stage."""
        workflow = DefaultWorkflow()
        first_stage = workflow._registry[0]
        assert first_stage.__name__ == ExperimentSetup.__name__
        setup_kwargs = workflow._registry_kwargs[0]
        assert len(setup_kwargs.keys()) == 1

    def test_init_with_args(self):
        """Explicit arguments are forwarded to the first stage's kwargs.

        Three keys are expected although three arguments are passed on top of
        the single default key seen in ``test_init_no_args``; presumably the
        ``None``-valued ``model`` is not forwarded (TODO confirm against
        ``DefaultWorkflow.__init__``).
        """
        workflow = DefaultWorkflow(stations="test", start="2020", model=None)
        first_stage = workflow._registry[0]
        assert first_stage.__name__ == ExperimentSetup.__name__
        setup_kwargs = workflow._registry_kwargs[0]
        assert len(setup_kwargs.keys()) == 3

    def test_init_with_kwargs(self):
        """Unknown keyword arguments are passed through in a fixed key order."""
        workflow = DefaultWorkflow(stations="test", real_kwarg=4)
        first_stage = workflow._registry[0]
        assert first_stage.__name__ == ExperimentSetup.__name__
        setup_kwargs = workflow._registry_kwargs[0]
        assert len(setup_kwargs.keys()) == 3
        expected_keys = ["experiment_date", "stations", "real_kwarg"]
        assert list(setup_kwargs.keys()) == expected_keys

    def test_setup(self):
        """The default workflow registers five stages in this exact order."""
        workflow = DefaultWorkflow()
        expected_stages = [ExperimentSetup, PreProcessing, ModelSetup, Training, PostProcessing]
        assert len(workflow._registry) == len(expected_stages)
        registered_names = [stage.__name__ for stage in workflow._registry]
        expected_names = [stage.__name__ for stage in expected_stages]
        assert registered_names == expected_names


class TestDefaultWorkflowHPC:
    """Check the stage registry built by ``DefaultWorkflowHPC``."""

    def test_setup(self):
        """The HPC workflow registers six stages, with ``PartitionCheck`` right after ``PreProcessing``."""
        workflow = DefaultWorkflowHPC()
        expected_stages = [
            ExperimentSetup,
            PreProcessing,
            PartitionCheck,  # the extra stage compared with DefaultWorkflow
            ModelSetup,
            Training,
            PostProcessing,
        ]
        assert len(workflow._registry) == len(expected_stages)
        registered_names = [stage.__name__ for stage in workflow._registry]
        expected_names = [stage.__name__ for stage in expected_stages]
        assert registered_names == expected_names