diff --git a/dags/conn_deco.py b/dags/conn_deco.py index 89935fdff02d0e5aab85225c5704cc88e443aaf7..4d1fe66990b4359b85882144c1c0fe7914fdfcea 100644 --- a/dags/conn_deco.py +++ b/dags/conn_deco.py @@ -1,12 +1,10 @@ from datetime import timedelta -from airflow import settings from airflow.decorators import dag, task -from airflow.providers.ssh.hooks.ssh import SSHHook -from airflow.models.connection import Connection -from airflow.utils.dates import days_ago -from airflow.providers.hashicorp.hooks.vault import VaultHook from airflow.operators.python import PythonOperator +from airflow.utils.dates import days_ago + +from decors import get_connection, remove, setup def_args = { 'owner': 'airflow', @@ -18,90 +16,13 @@ def_args = { } -def create_temp_connection(rrid, params): - host = params.get('host') - port = params.get('port', 2222) - user = params.get('login', 'eflows') - key = params.get('key') - - conn_id = f"tmp_connection_{rrid}" - extra = {"private_key": key} - conn = Connection( - conn_id=conn_id, - conn_type='ssh', - description='Automatically generated Connection', - host=host, - login=user, - port=port, - extra=extra - ) - - session = settings.Session() - session.add(conn) - session.commit() - print(f"Connection {conn_id} created") - return conn_id - -def get_connection(conn_id): - if conn_id.startswith('vault'): - vault_hook = VaultHook(vault_conn_id='my_vault') - con = vault_hook.get_secret(secret_path=f"ssh-credentials/{conn_id[6:]}") - print(f"Got some values from vault {list(con.keys())}") - - # for now SSH is hardcoded - host = con.get('host', 'bsc') - port = int(con.get('port', 22)) - hook = SSHHook(remote_host=host, port=port, username=con['userName']) - #key in vault should be in form of formated string: - #-----BEGIN OPENSSH PRIVATE KEY----- - #b3BlbnNzaC1rZXktdjEAAAAA - #.... - hook.pkey = hook._pkey_from_private_key(private_key=con['privateKey']) - return hook - - # otherwise use previously created temp connection - return SSHHook(ssh_conn_id=conn_id) - - -def setup(**kwargs): - params = kwargs['params'] - print(f"Setting up the connection", params ) - - - if 'vault_id' in params: - print('Retrieving connection details from vault') - return f"vault_{params['vault_id']}" - - # otherwise use creds provided in request - return create_temp_connection(rrid = kwargs['run_id'], params=params) - -def remove(conn_id): - if conn_id.startswith('vault'): - return - - print(f"Removing conneciton {conn_id}") - session = settings.Session() - for con in session.query(Connection).all(): - print(con) - - session.query(Connection).filter(Connection.conn_id == conn_id).delete() - session.commit() - -def get_conn_id(**kwargs): - ti = kwargs['ti'] - conn_id = ti.xcom_pull(key='return_value', task_ids='setup_connection') - return conn_id - @dag(default_args=def_args, schedule_interval=None, start_date=days_ago(2), tags=['example']) def conn_decorator(): - - + @task() - def doing_nothing(a_id, **kwargs): - conn_id = get_conn_id(**kwargs) - print(f"Just doing nothing with {conn_id}") - print(f"I also got ", a_id) - + def doing_nothing(conn_id): + print(f"Using connection {conn_id}") + ssh_hook = get_connection(conn_id=conn_id) with ssh_hook.get_conn() as ssh_client: sftp_client = ssh_client.open_sftp() @@ -115,8 +36,9 @@ def conn_decorator(): conn_id = PythonOperator(python_callable=setup, task_id='setup_connection') # another way of mixing taskflow and classical api: a_id = conn_id.output['return_value'] - dno = doing_nothing(a_id=a_id) - en = PythonOperator(python_callable=remove, op_kwargs={'conn_id': dno}, task_id='cleanup') + dno = doing_nothing(conn_id=a_id) + en = PythonOperator(python_callable=remove, op_kwargs={ + 'conn_id': dno}, task_id='cleanup') conn_id >> dno >> en diff --git a/dags/decors.py b/dags/decors.py new file mode 100644 index 0000000000000000000000000000000000000000..a02a7c3eefac948f7390c110fdc5169b393205ad --- /dev/null +++ b/dags/decors.py @@ -0,0 +1,82 @@ +from airflow import settings +from airflow.providers.ssh.hooks.ssh import SSHHook +from airflow.models.connection import Connection +from airflow.providers.hashicorp.hooks.vault import VaultHook + + +def create_temp_connection(rrid, params): + host = params.get('host') + port = params.get('port', 2222) + user = params.get('login', 'eflows') + key = params.get('key') + + conn_id = f"tmp_connection_{rrid}" + extra = {"private_key": key} + conn = Connection( + conn_id=conn_id, + conn_type='ssh', + description='Automatically generated Connection', + host=host, + login=user, + port=port, + extra=extra + ) + + session = settings.Session() + session.add(conn) + session.commit() + print(f"Connection {conn_id} created") + return conn_id + + +def get_connection(conn_id): + if conn_id.startswith('vault'): + vault_hook = VaultHook(vault_conn_id='my_vault') + con = vault_hook.get_secret( + secret_path=f"ssh-credentials/{conn_id[6:]}") + print(f"Got some values from vault {list(con.keys())}") + + # for now SSH is hardcoded + host = con.get('host', 'bsc') + port = int(con.get('port', 22)) + hook = SSHHook(remote_host=host, port=port, username=con['userName']) + # key in vault should be in form of formated string: + # -----BEGIN OPENSSH PRIVATE KEY----- + # b3BlbnNzaC1rZXktdjEAAAAA + # .... + hook.pkey = hook._pkey_from_private_key(private_key=con['privateKey']) + return hook + + # otherwise use previously created temp connection + return SSHHook(ssh_conn_id=conn_id) + + +def setup(**kwargs): + params = kwargs['params'] + print("Setting up the connection") + + if 'vault_id' in params: + print('Retrieving connection details from vault') + return f"vault_{params['vault_id']}" + + # otherwise use creds provided in request + return create_temp_connection(rrid=kwargs['run_id'], params=params) + + +def remove(conn_id): + if conn_id.startswith('vault'): + return + + print(f"Removing conneciton {conn_id}") + session = settings.Session() + for con in session.query(Connection).all(): + print(con) + + session.query(Connection).filter(Connection.conn_id == conn_id).delete() + session.commit() + + +def get_conn_id(**kwargs): + ti = kwargs['ti'] + conn_id = ti.xcom_pull(key='return_value', task_ids='setup_connection') + return conn_id diff --git a/dags/uploadflow.py b/dags/uploadflow.py index 1db56234655d4dd8004e2a3b43313111504ec2b0..454309e48e60c3441cca266f1c9e3c7bd48b1862 100644 --- a/dags/uploadflow.py +++ b/dags/uploadflow.py @@ -2,58 +2,49 @@ import os import tempfile -from airflow import settings from airflow.decorators import dag, task from airflow.models.connection import Connection +from airflow.operators.python import PythonOperator from airflow.providers.http.hooks.http import HttpHook from airflow.providers.ssh.hooks.ssh import SSHHook from airflow.utils.dates import days_ago -from b2shareoperator import (add_file, create_draft_record, - get_community, submit_draft) +from b2shareoperator import (add_file, create_draft_record, get_community, + submit_draft) +from decors import remove, setup default_args = { 'owner': 'airflow', } +def create_template(hrespo): + return { + "titles": [{"title": hrespo['title']}], + "creators": [{"creator_name": hrespo['creator_name']}], + "descriptions": [ + { + "description": hrespo['description'], + "description_type": "Abstract" + } + ], + "community": "2d58eb08-af65-4cad-bd25-92f1a17d325b", + "community_specific": { + "90942261-4637-4ac0-97b8-12e1edb38739": {"helmholtz centre": ["Forschungszentrum Jülich"]} + }, + "open_access": hrespo['open_access'] == "True" + } + + @dag(default_args=default_args, schedule_interval=None, start_date=days_ago(2), tags=['example']) def upload_example(): - @task - def setup(**kwargs): - print(f"Setting up the connection") - - params = kwargs['params'] - rrid = kwargs['run_id'] - host = params.get('host') - port = params.get('port', 2222) - user = params.get('login', 'eflows') - key = params.get('key') - - conn_id = f"tmp_connection_{rrid}" - extra = {"private_key": key} - conn = Connection( - conn_id=conn_id, - conn_type='ssh', - description='Automatically generated Connection', - host=host, - login=user, - port=port, - extra=extra - ) - - session = settings.Session() - session.add(conn) - session.commit() - print(f"Connection {conn_id} created") - return conn_id @task() def load(connection_id, **kwargs): params = kwargs['params'] target = params.get('target', '/tmp/') source = params.get('source', '/tmp/') - + ssh_hook = SSHHook(ssh_conn_id=connection_id) with ssh_hook.get_conn() as ssh_client: sftp_client = ssh_client.open_sftp() @@ -72,34 +63,14 @@ def upload_example(): mappings[local] = fname return mappings - - def create_template(hrespo): - return { - "titles" : [{"title": hrespo['title']}], - "creators" : [{"creator_name": hrespo['creator_name']}], - "descriptions" :[ - { - "description": hrespo['description'], - "description_type": "Abstract" - } - ], - "community" : "2d58eb08-af65-4cad-bd25-92f1a17d325b", - "community_specific" :{ - "90942261-4637-4ac0-97b8-12e1edb38739": {"helmholtz centre": ["Forschungszentrum Jülich"]} - }, - "open_access": hrespo['open_access']=="True" - } - - @task() def upload(files: dict, **kwargs): connection = Connection.get_connection_from_secrets('default_b2share') - # hate such hacks: + # hate such hacks: server = "https://" + connection.host token = connection.extra_dejson['access_token'] - params = kwargs['params'] mid = params['mid'] @@ -107,7 +78,8 @@ def upload_example(): hrespo = hook.run(endpoint=f"storage_target/{mid}").json()['metadata'] print(hrespo) template = create_template(hrespo=hrespo) - community = get_community(server=server, community_id=template['community']) + community = get_community( + server=server, community_id=template['community']) if not community: print("Not existing community") return @@ -115,36 +87,32 @@ def upload_example(): missing = [r for r in required if r not in template] if any(missing): print(f"Community {cid} required field {missing} are missing. This could pose some problems") - - r = create_draft_record(server=server, token=token, record=template) - print(r) print(f"Draft record created {r['id']} --> {r['links']['self']}") for [local, true_name] in files.items(): print(f"Uploading {local} --> {true_name}") - up = add_file(record=r, fname=local, token=token, remote=true_name) + _ = add_file(record=r, fname=local, token=token, remote=true_name) print("Submitting record for pubication") submitted = submit_draft(record=r, token=token) print(f"Record created {submitted['id']}") return submitted['id'] - @task() - def remove(conn_id, uid): - print(f"Upload {uid} completed. Removing conneciton {conn_id}") - session = settings.Session() - for con in session.query(Connection).all(): - print(con) - session.query(Connection).filter(Connection.conn_id == conn_id).delete() - session.commit() - conn_id = setup() - files = load(connection_id=conn_id) + setup_task = PythonOperator( + python_callable=setup, task_id='setup_connection') + a_id = setup_task.output['return_value'] + + files = load(connection_id=a_id) uid = upload(files) - remove(conn_id=conn_id, uid=uid) + + en = PythonOperator(python_callable=remove, op_kwargs={ + 'conn_id': a_id}, task_id='cleanup') + + setup_task >> files >> uid >> en dag = upload_example()