from mlair.workflows.abstract_workflow import Workflow

import logging


class TestWorkflow:
    """Checks on the stage registry, stored kwargs and run order of ``Workflow``."""

    def test_init(self):
        """A new workflow has empty registries and a default or user-given name."""
        default_flow = Workflow()
        assert len(default_flow._registry_kwargs.keys()) == 0
        assert len(default_flow._registry) == 0
        assert default_flow._name == "Workflow"

        named_flow = Workflow(name="river")
        assert named_flow._name == "river"

    def test_add(self):
        """Each ``add`` call appends the stage and stores its kwargs by position."""
        flow = Workflow()

        # A stage added without keyword arguments gets an empty kwargs entry.
        flow.add("stage")
        assert len(flow._registry_kwargs.keys()) == 1
        assert len(flow._registry) == 1
        first_kwargs = flow._registry_kwargs[0]
        assert len(first_kwargs.keys()) == 0

        # A stage added with keyword arguments keeps them under the next index.
        flow.add("stagekwargs", test=23, another="string")
        assert len(flow._registry_kwargs.keys()) == 2
        assert len(flow._registry) == 2
        second_kwargs = flow._registry_kwargs[1]
        assert len(second_kwargs.keys()) == 2
        assert list(flow._registry_kwargs.keys()) == [0, 1]
        assert flow._registry == ["stage", "stagekwargs"]
        assert list(second_kwargs.keys()) == ["test", "another"]
        assert second_kwargs["another"] == "string"

    def test_run(self, caplog):
        """Stages are instantiated in insertion order with their stored kwargs."""
        caplog.set_level(logging.INFO)

        class StageWithArgument:
            """Dummy stage that logs the value it was constructed with."""

            def __init__(self, a=3):
                self.a = a
                logging.info(self.a)

        class StageWithoutArgument:
            """Dummy stage that logs a fixed value."""

            def __init__(self):
                self.b = 2
                logging.info(self.b)

        flow = Workflow()
        flow.add(StageWithArgument, a=6)
        flow.add(StageWithoutArgument)
        flow.add(StageWithArgument)
        flow.run()

        # Record 0 is not inspected; the checked records start at index 1.
        expected_messages = ["Workflow started", "6", "2", "3"]
        for position, message in enumerate(expected_messages, start=1):
            assert caplog.record_tuples[position] == ('root', logging.INFO, message)