From fe17a24d459792c6a4ba882b77ef5855e03bded1 Mon Sep 17 00:00:00 2001 From: jrybicki-jsc <j.rybicki@fz-juelich.de> Date: Mon, 13 Dec 2021 10:58:28 +0100 Subject: [PATCH] taskflow with connection decorator --- dags/taskflow.py | 57 +++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 51 insertions(+), 6 deletions(-) diff --git a/dags/taskflow.py b/dags/taskflow.py index c86066f..a410957 100644 --- a/dags/taskflow.py +++ b/dags/taskflow.py @@ -1,4 +1,5 @@ +from airflow import settings from airflow.decorators import dag, task from airflow.models.connection import Connection from airflow.providers.ssh.hooks.ssh import SSHHook @@ -16,8 +17,38 @@ default_args = { @dag(default_args=default_args, schedule_interval=None, start_date=days_ago(2), tags=['example']) def taskflow_example(): + + @task + def setup(**kwargs): + print(f"Setting up the connection") + + params = kwargs['params'] + rrid = kwargs['run_id'] + host = params.get('host') + port = params.get('port', 2222) + user = params.get('login', 'eflows') + key = params.get('key') + + conn_id = f"tmp_connection_{rrid}" + extra = {"private_key": key} + conn = Connection( + conn_id=conn_id, + conn_type='ssh', + description='Automatically generated Connection', + host=host, + login=user, + port=port, + extra=extra + ) + + session = settings.Session() + session.add(conn) + session.commit() + print(f"Connection {conn_id} created") + return conn_id + @task(multiple_outputs=True) - def extract(**kwargs): + def extract(conn_id, **kwargs): connection = Connection.get_connection_from_secrets('default_b2share') server = connection.get_uri() print(f"Rereiving data from {server}") @@ -50,11 +81,12 @@ def taskflow_example(): return name_mappings @task() - def load(files: dict, **kwargs): + def load(connection_id, files: dict, **kwargs): print(f"Total files downloaded: {len(files)}") params = kwargs['params'] target = params.get('target', '/tmp/') - connection_id = params.get('connection', 'default_ssh') + + print(f"Using {connection_id} connection") ssh_hook = SSHHook(ssh_conn_id=connection_id) with ssh_hook.get_conn() as ssh_client: @@ -65,9 +97,22 @@ def taskflow_example(): # or separate cleanup task? os.unlink(local) - data = extract() - files = transform(data) - load(files) + return connection_id + + @task() + def remove(conn_id): + print(f"Removing conneciton {conn_id}") + session = settings.Session() + for con in session.query(Connection).all(): + print(con) + + session.query(Connection).filter(Connection.conn_id == conn_id).delete() + session.commit() + conn_id = setup() + data = extract(conn_id) + files = transform(data) + ucid = load(connection_id = conn_id, files=files) + remove(conn_id=ucid) dag = taskflow_example() -- GitLab