diff --git a/mlair/workflows/abstract_workflow.py b/mlair/workflows/abstract_workflow.py index 3a627d9f72a5c1c97c35b464af1b0944bc397ea5..c969aa35ebca60aa749a294bcaa5de727407a461 100644 --- a/mlair/workflows/abstract_workflow.py +++ b/mlair/workflows/abstract_workflow.py @@ -3,8 +3,6 @@ __author__ = "Lukas Leufen" __date__ = '2020-06-26' -from collections import OrderedDict - from mlair import RunEnvironment diff --git a/test/test_workflows/test_abstract_workflow.py b/test/test_workflows/test_abstract_workflow.py new file mode 100644 index 0000000000000000000000000000000000000000..c303b7f1e48921df6be5a04ba2fee1cbc69afa35 --- /dev/null +++ b/test/test_workflows/test_abstract_workflow.py @@ -0,0 +1,52 @@ +from mlair.workflows.abstract_workflow import Workflow + +import logging + + +class TestWorkflow: + + def test_init(self): + flow = Workflow() + assert len(flow._registry_kwargs.keys()) == 0 + assert len(flow._registry) == 0 + assert flow._name == "Workflow" + flow = Workflow(name="river") + assert flow._name == "river" + + def test_add(self): + flow = Workflow() + flow.add("stage") + assert len(flow._registry_kwargs.keys()) == 1 + assert len(flow._registry) == 1 + assert len(flow._registry_kwargs[0].keys()) == 0 + flow.add("stagekwargs", test=23, another="string") + assert len(flow._registry_kwargs.keys()) == 2 + assert len(flow._registry) == 2 + assert len(flow._registry_kwargs[1].keys()) == 2 + assert list(flow._registry_kwargs.keys()) == [0, 1] + assert flow._registry == ["stage", "stagekwargs"] + assert list(flow._registry_kwargs[1].keys()) == ["test", "another"] + assert flow._registry_kwargs[1]["another"] == "string" + + def test_run(self, caplog): + caplog.set_level(logging.INFO) + + class A: + def __init__(self, a=3): + self.a = a + logging.info(self.a) + + class B: + def __init__(self): + self.b = 2 + logging.info(self.b) + + flow = Workflow() + flow.add(A, a=6) + flow.add(B) + flow.add(A) + flow.run() + assert caplog.record_tuples[1] == ('root', 20, "Workflow started") + assert caplog.record_tuples[2] == ('root', 20, "6") + assert caplog.record_tuples[3] == ('root', 20, "2") + assert caplog.record_tuples[4] == ('root', 20, "3")