From ac97612342fc6d07a67ae31ffba5b03886d3f57a Mon Sep 17 00:00:00 2001
From: jrybicki-jsc <j.rybicki@fz-juelich.de>
Date: Wed, 27 Jul 2022 08:08:06 +0200
Subject: [PATCH] first take on webdav transfers

---
 dags/uploadflow.py     | 37 ++++++++++++---------
 dags/webdav_example.py | 74 ++++++++++++++++++++++++++++++++++++++++++
 requirements.txt       |  1 +
 3 files changed, 96 insertions(+), 16 deletions(-)
 create mode 100644 dags/webdav_example.py

diff --git a/dags/uploadflow.py b/dags/uploadflow.py
index 72c4d7e..6c65363 100644
--- a/dags/uploadflow.py
+++ b/dags/uploadflow.py
@@ -38,20 +38,12 @@ def create_template(hrespo):
         "open_access": hrespo['open_access'] == "True"
     }
 
-
-@dag(default_args=default_args, schedule_interval=None, start_date=days_ago(2), tags=['example'])
-def upload_example():
-
-    @task()
-    def load(connection_id, **kwargs):
-        params = kwargs['params']
-        target = Variable.get("working_dir", default_var='/tmp/')
-        source = params.get('source', '/tmp/')
-
-        ssh_hook = get_connection(conn_id=connection_id, **kwargs)
-        with ssh_hook.get_conn() as ssh_client:
+def ssh2local_copy(ssh_hook, source: str, target: str):
+    with ssh_hook.get_conn() as ssh_client:
             sftp_client = ssh_client.open_sftp()
             lst = sftp_client.listdir(path=source)
+            
+            print(f"{len(lst)} objects in {source}")
             mappings = dict()
             for fname in lst:
                 local = tempfile.mktemp(prefix='dls', dir=target)
@@ -61,10 +53,23 @@ def upload_example():
                     print(f"{full_name} is a directory. Skipping")
                     continue
 
-                print(f"Copying {connection_id}:{full_name} --> {local}")
-                sftp_client.get(os.path.join(source, fname), local)
+                print(f"Copying {full_name} --> {local}")
+                sftp_client.get(full_name, local)
                 mappings[local] = fname
 
+    return mappings
+
+@dag(default_args=default_args, schedule_interval=None, start_date=days_ago(2), tags=['example'])
+def upload_example():
+
+    @task()
+    def load(connection_id, **kwargs):
+        params = kwargs['params']
+        target = Variable.get("working_dir", default_var='/tmp/')
+        source = params.get('source', '/tmp/')
+        ssh_hook = get_connection(conn_id=connection_id, **kwargs)
+        
+        mappings = ssh2local_copy(ssh_hook=ssh_hook, source=source, target=target)
         return mappings
 
     @task()
@@ -135,7 +140,7 @@ def upload_example():
         except ConnectionError as e:
             print('Registration failed', e)
             return -1
-
+    
     setup_task = PythonOperator(python_callable=setup, task_id='setup_connection')
     a_id = setup_task.output['return_value']
 
@@ -146,7 +151,7 @@ def upload_example():
                         'conn_id': a_id}, task_id='cleanup')
 
     reg = register(object_url=uid)
-
+    
     setup_task >> files >> uid >> reg >> en
 
 
diff --git a/dags/webdav_example.py b/dags/webdav_example.py
new file mode 100644
index 0000000..9829bfd
--- /dev/null
+++ b/dags/webdav_example.py
@@ -0,0 +1,74 @@
+import os
+
+from airflow.decorators import dag, task
+from airflow.models import Variable
+from airflow.models.connection import Connection
+from airflow.operators.python import PythonOperator
+from airflow.utils.dates import days_ago
+from webdav3.client import Client
+
+from uploadflow import ssh2local_copy
+from decors import get_connection, remove, setup
+
+default_args = {
+    'owner': 'airflow',
+}
+
+
+@dag(default_args=default_args, schedule_interval=None, start_date=days_ago(2), tags=['example'])
+def webdav_upload():
+
+    @task()
+    def download(connection_id, **kwargs):
+        
+        params = kwargs['params']
+        target = Variable.get("working_dir", default_var='/tmp/')
+        source = params.get('source', '/tmp/')
+        ssh_hook = get_connection(conn_id=connection_id, **kwargs)
+
+        mappings = ssh2local_copy(ssh_hook=ssh_hook, source=source, target=target)
+        
+        return mappings
+
+    @task()
+    def load(mappings, **kwargs):
+        params = kwargs['params']
+        target = params.get('target', '/airflow-test')
+        connection = Connection.get_connection_from_secrets('b2drop_webdav')
+        options = {'webdav_hostname': f"https://{connection.host}{connection.schema}",
+                   'webdav_login': connection.login,
+                   'webdav_password': connection.get_password()
+                   }
+        print(f"Translated http to webdav: {options}")
+        client = Client(options)
+        res = client.mkdir(target)
+        print(f"Creating {target}: {'ok' if res else 'failed'}")
+
+        print(f"Starting upload -> {target}")
+        for [local, true_name] in mappings.items():
+            full_name = os.path.join(target, true_name)
+            print(f"Processing {local} --> {full_name}")
+            client.upload_sync(remote_path=full_name, local_path=local)
+
+            # delete local
+            os.unlink(local)
+
+        return True
+
+    @task
+    def print_stats(res):
+        print('Finished')
+
+    setup_task = PythonOperator(
+        python_callable=setup, task_id='setup_connection')
+    a_id = setup_task.output['return_value']
+
+    mappings = download(connection_id=a_id)
+    res = load(mappings=mappings)
+    
+    en = PythonOperator(python_callable=remove, op_kwargs={
+                        'conn_id': a_id}, task_id='cleanup')
+    res >> en
+
+
+dag = webdav_upload()
diff --git a/requirements.txt b/requirements.txt
index d0b08be..920dac0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,7 @@
 requests
 urllib3
 plyvel
+webdavclient3
 apache-airflow-providers-ssh
 apache-airflow-providers-http
 apache-airflow-providers-sftp
-- 
GitLab