From 5ccb37b5603457b8df8ba471e68ed77898bb38a1 Mon Sep 17 00:00:00 2001
From: jrybicki-jsc <j.rybicki@fz-juelich.de>
Date: Wed, 23 Mar 2022 11:14:30 +0100
Subject: [PATCH] directory making for upload, and removal temp files for
 registration

---
 dags/taskflow.py   | 2 ++
 dags/uploadflow.py | 6 +++++-
 2 files changed, 7 insertions(+), 1 deletion(-)

diff --git a/dags/taskflow.py b/dags/taskflow.py
index cd9b05c..3e18ad6 100644
--- a/dags/taskflow.py
+++ b/dags/taskflow.py
@@ -62,6 +62,8 @@ def taskflow_example():
 
         with ssh_hook.get_conn() as ssh_client:
             sftp_client = ssh_client.open_sftp()
+            # check dir?
+            sftp_client.mkdir(target, ignore_existing=True)
             for [truename, local] in files.items():
                 print(
                     f"Copying {local} --> {connection_id}:{os.path.join(target, truename)}")
diff --git a/dags/uploadflow.py b/dags/uploadflow.py
index 7feb46d..e38a40f 100644
--- a/dags/uploadflow.py
+++ b/dags/uploadflow.py
@@ -7,6 +7,7 @@ from airflow.models.connection import Connection
 from airflow.operators.python import PythonOperator
 from airflow.providers.http.hooks.http import HttpHook
 from airflow.utils.dates import days_ago
+from airflow.models import Variable
 
 from b2shareoperator import (add_file, create_draft_record, get_community,
                              submit_draft)
@@ -41,7 +42,7 @@ def upload_example():
     @task()
     def load(connection_id, **kwargs):
         params = kwargs['params']
-        target = params.get('target', '/tmp/')
+        target = Variable.get("working_dir", default_var='/tmp/')
         source = params.get('source', '/tmp/')
 
         ssh_hook = get_connection(conn_id=connection_id, **kwargs)
@@ -93,10 +94,13 @@ def upload_example():
         for [local, true_name] in files.items():
             print(f"Uploading {local} --> {true_name}")
             _ = add_file(record=r, fname=local, token=token, remote=true_name)
+            # delete local
+            os.unlink(local)
 
         print("Submitting record for pubication")
         submitted = submit_draft(record=r, token=token)
         print(f"Record created {submitted['id']}")
+
         return submitted['id']
 
 
-- 
GitLab