From fe17a24d459792c6a4ba882b77ef5855e03bded1 Mon Sep 17 00:00:00 2001
From: jrybicki-jsc <j.rybicki@fz-juelich.de>
Date: Mon, 13 Dec 2021 10:58:28 +0100
Subject: [PATCH] taskflow with connection decorator

---
 dags/taskflow.py | 57 +++++++++++++++++++++++++++++++++++++++++++-----
 1 file changed, 51 insertions(+), 6 deletions(-)

diff --git a/dags/taskflow.py b/dags/taskflow.py
index c86066f..a410957 100644
--- a/dags/taskflow.py
+++ b/dags/taskflow.py
@@ -1,4 +1,5 @@
 
+from airflow import settings
 from airflow.decorators import dag, task
 from airflow.models.connection import Connection
 from airflow.providers.ssh.hooks.ssh import SSHHook
@@ -16,8 +17,38 @@ default_args = {
 
 @dag(default_args=default_args, schedule_interval=None, start_date=days_ago(2), tags=['example'])
 def taskflow_example():
+
+    @task
+    def setup(**kwargs):
+        print(f"Setting up the connection")
+        
+        params = kwargs['params']
+        rrid = kwargs['run_id']
+        host = params.get('host')
+        port = params.get('port', 2222)
+        user = params.get('login', 'eflows')
+        key = params.get('key')
+
+        conn_id = f"tmp_connection_{rrid}"
+        extra = {"private_key": key}
+        conn = Connection(
+            conn_id=conn_id,
+            conn_type='ssh',
+            description='Automatically generated Connection',
+            host=host,
+            login=user,
+            port=port,
+            extra=extra
+        )
+
+        session = settings.Session()
+        session.add(conn)
+        session.commit()
+        print(f"Connection {conn_id} created")
+        return conn_id
+
     @task(multiple_outputs=True)
-    def extract(**kwargs):
+    def extract(conn_id, **kwargs):
         connection = Connection.get_connection_from_secrets('default_b2share')
         server = connection.get_uri()
         print(f"Rereiving data from {server}")
@@ -50,11 +81,12 @@ def taskflow_example():
         return name_mappings
 
     @task()
-    def load(files: dict, **kwargs):
+    def load(connection_id, files: dict, **kwargs):
         print(f"Total files downloaded: {len(files)}")
         params = kwargs['params']
         target = params.get('target', '/tmp/')
-        connection_id = params.get('connection', 'default_ssh')
+        
+        print(f"Using {connection_id} connection")
         
         ssh_hook = SSHHook(ssh_conn_id=connection_id)
         with ssh_hook.get_conn() as ssh_client:
@@ -65,9 +97,22 @@ def taskflow_example():
                 # or separate cleanup task?
                 os.unlink(local)
 
-    data = extract()
-    files = transform(data)
-    load(files)
+        return connection_id
+
+    @task()
+    def remove(conn_id):
+        print(f"Removing conneciton {conn_id}")
+        session = settings.Session()
+        for con in session.query(Connection).all():
+            print(con)
+
+        session.query(Connection).filter(Connection.conn_id == conn_id).delete()
+        session.commit()
 
+    conn_id = setup()
+    data = extract(conn_id)
+    files = transform(data)
+    ucid = load(connection_id = conn_id, files=files)
+    remove(conn_id=ucid)
 
 dag = taskflow_example()
-- 
GitLab