From ea9d985d724373d33da7d7c92cdec7cc1188aa2f Mon Sep 17 00:00:00 2001
From: jrybicki-jsc <j.rybicki@fz-juelich.de>
Date: Tue, 2 Nov 2021 15:25:32 +0100
Subject: [PATCH] upload to b2share, basic implementation

---
 dags/b2shareoperator.py       |  4 +-
 dags/taskflow.py              |  2 +
 dags/uploadflow.py            | 73 +++++++++++++++++++++++++++++++++++
 tests/test_b2shareoperator.py |  2 +-
 4 files changed, 78 insertions(+), 3 deletions(-)
 create mode 100644 dags/uploadflow.py

diff --git a/dags/b2shareoperator.py b/dags/b2shareoperator.py
index a509358..31f071a 100644
--- a/dags/b2shareoperator.py
+++ b/dags/b2shareoperator.py
@@ -48,8 +48,8 @@ def create_draft_record(server: str, token: str, record):
     return response.json()
 
 # the simplest version, target should be chunked
-def add_file(record, fname: str, token: str):
-    jf = os.path.split(fname)[-1]
+def add_file(record, fname: str, token: str, remote:str):
+    jf = os.path.split(remote)[-1]
     return requests.put(url=f"{record['links']['files']}/{jf}",
                          params={'access_token': token},
                          headers={"Content-Type":"application/octet-stream"},
diff --git a/dags/taskflow.py b/dags/taskflow.py
index d926e9c..c0153ed 100644
--- a/dags/taskflow.py
+++ b/dags/taskflow.py
@@ -58,6 +58,8 @@ def taskflow_example():
             for [truename, local] in files.items():
                 print(f"Copying {local} --> {connection_id}:{os.path.join(target, truename)}")
                 sftp_client.put(local, os.path.join(target, truename))
+                # or separate cleanup task?
+                os.unlink(local)
 
     data = extract()
     files = transform(data)
diff --git a/dags/uploadflow.py b/dags/uploadflow.py
new file mode 100644
index 0000000..5f99de9
--- /dev/null
+++ b/dags/uploadflow.py
@@ -0,0 +1,73 @@
+
+import os
+import tempfile
+
+from airflow.decorators import dag, task
+from airflow.models.connection import Connection
+from airflow.providers.ssh.hooks.ssh import SSHHook
+from airflow.utils.dates import days_ago
+
+from b2shareoperator import (add_file, create_draft_record,
+                             get_record_template, submit_draft)
+
+
+default_args = {
+    'owner': 'airflow',
+}
+
+
+@dag(default_args=default_args, schedule_interval=None, start_date=days_ago(2), tags=['example'])
+def upload_example():
+    @task()
+    def load(**kwargs):
+        params = kwargs['params']
+        target = params.get('target', '/tmp/')
+        source = params.get('source', '/tmp/')
+        connection_id = params.get('connection', 'default_ssh')
+        
+        ssh_hook = SSHHook(ssh_conn_id=connection_id)
+        with ssh_hook.get_conn() as ssh_client:
+            sftp_client = ssh_client.open_sftp()
+            lst = sftp_client.listdir(path=source)
+            mappings = dict()
+            for fname in lst:
+                local = tempfile.mktemp(prefix='dls', dir=target)
+                full_name = os.path.join(source, fname)
+                sts = sftp_client.stat(full_name)
+                if str(sts).startswith('d'):
+                    print(f"{full_name} is a directory. Skipping")
+                    continue
+
+                print(f"Copying {connection_id}:{full_name} --> {local}")
+                sftp_client.get(os.path.join(source, fname), local)
+                mappings[local] = fname
+
+        return mappings
+        
+
+    @task()
+    def upload(files: dict, **kwargs):
+        connection = Connection.get_connection_from_secrets('default_b2share')
+        # hate such hacks: 
+        server = "https://" + connection.host
+        token = connection.extra_dejson['access_token']
+        print(f"Server: {server} + {token}")
+        
+        template = get_record_template()
+        r = create_draft_record(server=server, token=token, record=template)
+        print(r)
+        print(f"Draft record created {r['id']} --> {r['links']['self']}")
+
+        for [local, true_name] in files.items():
+            print(f"Uploading {local} --> {true_name}")
+            up = add_file(record=r, fname=local, token=token, remote=true_name)
+
+        print("Submitting record for pubication")
+        submitted = submit_draft(record=r, token=token)
+        print(f"Record created {submitted['id']}")
+
+    files = load()
+    upload(files)
+
+
+dag = upload_example()
diff --git a/tests/test_b2shareoperator.py b/tests/test_b2shareoperator.py
index 9ecebf3..f873cc0 100644
--- a/tests/test_b2shareoperator.py
+++ b/tests/test_b2shareoperator.py
@@ -99,7 +99,7 @@ class B2ShareOperatorTest(unittest.TestCase):
         with patch('dags.b2shareoperator.requests.post') as put:
             a = tempfile.NamedTemporaryFile()
             a.write(b"some content")
-            up = add_file(record=r, fname=a.name, token=token)
+            up = add_file(record=r, fname=a.name, token=token, remote='/tmp/somefile.txt')
 
 
         with patch('dags.b2shareoperator.requests.patch') as p:
-- 
GitLab