From 969c9f79613aef934a1f33fb44ecc7378ea8411c Mon Sep 17 00:00:00 2001 From: leufen1 <l.leufen@fz-juelich.de> Date: Fri, 18 Dec 2020 09:43:35 +0100 Subject: [PATCH] tests for abstract workflow implemented, /close #244 --- mlair/workflows/abstract_workflow.py | 2 - test/test_workflows/test_abstract_workflow.py | 52 +++++++++++++++++++ 2 files changed, 52 insertions(+), 2 deletions(-) create mode 100644 test/test_workflows/test_abstract_workflow.py diff --git a/mlair/workflows/abstract_workflow.py b/mlair/workflows/abstract_workflow.py index 3a627d9f..c969aa35 100644 --- a/mlair/workflows/abstract_workflow.py +++ b/mlair/workflows/abstract_workflow.py @@ -3,8 +3,6 @@ __author__ = "Lukas Leufen" __date__ = '2020-06-26' -from collections import OrderedDict - from mlair import RunEnvironment diff --git a/test/test_workflows/test_abstract_workflow.py b/test/test_workflows/test_abstract_workflow.py new file mode 100644 index 00000000..c303b7f1 --- /dev/null +++ b/test/test_workflows/test_abstract_workflow.py @@ -0,0 +1,52 @@ +from mlair.workflows.abstract_workflow import Workflow + +import logging + + +class TestWorkflow: + + def test_init(self): + flow = Workflow() + assert len(flow._registry_kwargs.keys()) == 0 + assert len(flow._registry) == 0 + assert flow._name == "Workflow" + flow = Workflow(name="river") + assert flow._name == "river" + + def test_add(self): + flow = Workflow() + flow.add("stage") + assert len(flow._registry_kwargs.keys()) == 1 + assert len(flow._registry) == 1 + assert len(flow._registry_kwargs[0].keys()) == 0 + flow.add("stagekwargs", test=23, another="string") + assert len(flow._registry_kwargs.keys()) == 2 + assert len(flow._registry) == 2 + assert len(flow._registry_kwargs[1].keys()) == 2 + assert list(flow._registry_kwargs.keys()) == [0, 1] + assert flow._registry == ["stage", "stagekwargs"] + assert list(flow._registry_kwargs[1].keys()) == ["test", "another"] + assert flow._registry_kwargs[1]["another"] == "string" + + def test_run(self, caplog): + caplog.set_level(logging.INFO) + + class A: + def __init__(self, a=3): + self.a = a + logging.info(self.a) + + class B: + def __init__(self): + self.b = 2 + logging.info(self.b) + + flow = Workflow() + flow.add(A, a=6) + flow.add(B) + flow.add(A) + flow.run() + assert caplog.record_tuples[1] == ('root', 20, "Workflow started") + assert caplog.record_tuples[2] == ('root', 20, "6") + assert caplog.record_tuples[3] == ('root', 20, "2") + assert caplog.record_tuples[4] == ('root', 20, "3") -- GitLab