Commit acd48ea0 authored by Maria Petrova-El Sayed's avatar Maria Petrova-El Sayed

Merge branch 'main' into mptest

parents f0780244 ea9d985d
Pipeline #82471 passed
import json
import os
import tempfile
import urllib.request
from urllib.parse import urljoin
@@ -29,6 +31,37 @@ def download_file(url: str, target_dir: str):
    urllib.request.urlretrieve(url=url, filename=fname)
    return fname

def get_record_template():
    return {"titles": [{"title": "DLS dataset record"}],
            "creators": [{"creator_name": "eflows4HPC"}],
            "descriptions": [{"description": "Output of eflows4HPC DLS",
                              "description_type": "Abstract"}],
            "community": "a9217684-945b-4436-8632-cac271f894ed",
            "community_specific": {
                "91ae5d2a-3848-4693-9f7d-cbd141172ef0":
                    {"helmholtz centre": ["Forschungszentrum Jülich"]}},
            "open_access": True}

def create_draft_record(server: str, token: str, record):
    response = requests.post(url=urljoin(server, 'api/records/'),
                             headers={'Content-Type': 'application/json'},
                             data=json.dumps(record), params={'access_token': token})
    return response.json()

# the simplest version; for large files the upload should be chunked
def add_file(record, fname: str, token: str, remote: str):
    jf = os.path.split(remote)[-1]
    with open(fname, 'rb') as f:
        return requests.put(url=f"{record['links']['files']}/{jf}",
                            params={'access_token': token},
                            headers={"Content-Type": "application/octet-stream"},
                            data=f)

def submit_draft(record, token):
    # B2SHARE publishes a draft by patching publication_state via JSON Patch
    pub = [{"op": "add", "path": "/publication_state", "value": "submitted"}]
    response = requests.patch(record['links']['self'],
                              headers={"Content-Type": "application/json-patch+json"},
                              data=json.dumps(pub), params={'access_token': token})
    return response.json()
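
# A minimal usage sketch of the three helpers above (hypothetical values; the
# server URL is the B2SHARE test instance also used in the tests, TOKEN must be
# a valid B2SHARE API access token):
#
#   record = create_draft_record(server='https://b2share-testing.fz-juelich.de/',
#                                token=TOKEN, record=get_record_template())
#   add_file(record=record, fname='/tmp/result.dat', token=TOKEN, remote='/tmp/result.dat')
#   submit_draft(record=record, token=TOKEN)
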
class B2ShareOperator(BaseOperator):
    template_fields = ('target_dir',)
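    # Airflow renders fields listed in template_fields with Jinja, so
    # target_dir can be set per dag run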
......
@@ -50,13 +50,16 @@ def taskflow_example():
print(f"Total files downloaded: {len(files)}")
params = kwargs['params']
target = params.get('target', '/tmp/')
connection_id = params.get('connection', 'default_ssh')
ssh_hook = SSHHook(ssh_conn_id='default_ssh')
ssh_hook = SSHHook(ssh_conn_id=connection_id)
with ssh_hook.get_conn() as ssh_client:
sftp_client = ssh_client.open_sftp()
for [truename, local] in files.items():
print(f"Copying {local} --> {os.path.join(target, truename)}")
print(f"Copying {local} --> {connection_id}:{os.path.join(target, truename)}")
sftp_client.put(local, os.path.join(target, truename))
# or separate cleanup task?
os.unlink(local)

    data = extract()
    files = transform(data)
......
import os
import stat
import tempfile

from airflow.decorators import dag, task
from airflow.models.connection import Connection
from airflow.providers.ssh.hooks.ssh import SSHHook
from airflow.utils.dates import days_ago

from b2shareoperator import (add_file, create_draft_record,
                             get_record_template, submit_draft)

default_args = {
    'owner': 'airflow',
}

@dag(default_args=default_args, schedule_interval=None, start_date=days_ago(2), tags=['example'])
def upload_example():
    @task()
    def load(**kwargs):
        params = kwargs['params']
        target = params.get('target', '/tmp/')
        source = params.get('source', '/tmp/')
        connection_id = params.get('connection', 'default_ssh')

        ssh_hook = SSHHook(ssh_conn_id=connection_id)
        with ssh_hook.get_conn() as ssh_client:
            sftp_client = ssh_client.open_sftp()
            lst = sftp_client.listdir(path=source)
            mappings = dict()
            for fname in lst:
                # mktemp only reserves a name; sftp_client.get below creates the file
                local = tempfile.mktemp(prefix='dls', dir=target)
                full_name = os.path.join(source, fname)
                sts = sftp_client.stat(full_name)
                # skip directories (S_ISDIR is more robust than parsing str(sts))
                if stat.S_ISDIR(sts.st_mode):
                    print(f"{full_name} is a directory. Skipping")
                    continue

                print(f"Copying {connection_id}:{full_name} --> {local}")
                sftp_client.get(full_name, local)
                mappings[local] = fname

        return mappings

    @task()
    def upload(files: dict, **kwargs):
        connection = Connection.get_connection_from_secrets('default_b2share')
        # the connection stores only the host, so the scheme has to be glued on here
        server = "https://" + connection.host
        token = connection.extra_dejson['access_token']
        print(f"Server: {server}")

        template = get_record_template()
        r = create_draft_record(server=server, token=token, record=template)
        print(r)
        print(f"Draft record created {r['id']} --> {r['links']['self']}")

        for local, true_name in files.items():
            print(f"Uploading {local} --> {true_name}")
            up = add_file(record=r, fname=local, token=token, remote=true_name)

        print("Submitting record for publication")
        submitted = submit_draft(record=r, token=token)
        print(f"Record created {submitted['id']}")

    files = load()
    upload(files)


dag = upload_example()
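
# The 'source', 'target' and 'connection' keys read in load() arrive via the
# dag run conf; a hypothetical trigger payload (all values are placeholders):
#   {"conf": {"source": "/source/dir", "target": "/tmp/", "connection": "default_ssh"}}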
@@ -20,7 +20,7 @@ curl -X POST -u USER:PASS -H "Content-Type: application/json" \
There should now be an object created in B2SHARE. Each object in B2SHARE is identified by an +id+, which needs to be passed to the DLS workflow as a parameter (see below).
Also a connection to the SSH target needs to be created (where the data will be copied to). You are free to use your own connection name here and pass it along during execution (the default is 'default_ssh'):
----
curl -X POST -u USER:PASS -H "Content-Type: application/json" \
@@ -44,6 +44,15 @@ curl -X POST -u USER:PASS -H "Content-Type: application/json" \
$DLS/dags/taskflow_example/dagRuns
----
If you want to use your own connection ('myown_con'):
----
curl -X POST -u USER:PASS -H "Content-Type: application/json" \
    --data '{"conf": {"oid": ID, "target": PATH, "connection": "myown_con"}}' \
    $DLS/dags/taskflow_example/dagRuns
----
=== Checking status
----
curl -X GET -u USER:PASS -H "Content-Type: application/json" $DLS/dags/taskflow_example/dagRuns
----
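
The status check can also be scripted. A minimal Python sketch, assuming the Airflow stable REST API (each returned dag run carries a +state+ field) and the same USER/PASS/$DLS placeholders as in the curl examples:

----
import requests

# Placeholders: point DLS at the Airflow API base used in the curl examples
DLS = "http://localhost:8080/api/v1"
USER, PASS = "USER", "PASS"

resp = requests.get(f"{DLS}/dags/taskflow_example/dagRuns", auth=(USER, PASS))
resp.raise_for_status()
for run in resp.json().get("dag_runs", []):
    print(run["dag_run_id"], run["state"])
----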
......
import unittest
from unittest.mock import Mock, patch
import tempfile
import os
from airflow import DAG
from airflow.models.taskinstance import TaskInstance
@@ -7,7 +9,8 @@ from airflow.utils.dates import days_ago
from airflow.utils.state import State
from dags.b2shareoperator import (B2ShareOperator, download_file,
                                  get_file_list, get_object_md, get_objects,
                                  get_record_template, create_draft_record,
                                  add_file, submit_draft)
DEFAULT_DATE = '2019-10-03'
TEST_DAG_ID = 'test_my_custom_operator'
@@ -83,3 +86,22 @@ class B2ShareOperatorTest(unittest.TestCase):
        get.return_value = m
        r = get_objects(server='foo')
        self.assertListEqual(['a', 'b'], r)

    def test_upload(self):
        template = get_record_template()
        server = 'https://b2share-testing.fz-juelich.de/'
        token = ''
        with patch('dags.b2shareoperator.requests.post') as post:
            r = create_draft_record(server=server, token=token, record=template)

        r = dict()
        r['links'] = {'files': server, 'self': server}
        # add_file issues a PUT, so that is what has to be patched here
        with patch('dags.b2shareoperator.requests.put') as put:
            a = tempfile.NamedTemporaryFile()
            a.write(b"some content")
            a.flush()
            up = add_file(record=r, fname=a.name, token=token, remote='/tmp/somefile.txt')

        with patch('dags.b2shareoperator.requests.patch') as p:
            submitted = submit_draft(record=r, token=token)
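
# Module path below is an assumption (tests run from the repo root, matching
# the 'dags.b2shareoperator' imports above):
#   python -m unittest test_b2shareoperator.B2ShareOperatorTest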