Skip to content
Snippets Groups Projects
Commit 31481e74 authored by Jedrzej Rybicki's avatar Jedrzej Rybicki
Browse files

check testing

parent d1d0fe97
No related branches found
No related tags found
No related merge requests found
from datetime import timedelta
from airflow import settings
from airflow.decorators import dag, task
from airflow.providers.ssh.hooks.ssh import SSHHook
from airflow.models.connection import Connection
from airflow.utils.dates import days_ago
from airflow.providers.hashicorp.hooks.vault import VaultHook
from airflow.operators.python import PythonOperator
from airflow.utils.dates import days_ago
from decors import get_connection, remove, setup
def_args = {
'owner': 'airflow',
......@@ -18,89 +16,12 @@ def_args = {
}
def create_temp_connection(rrid, params):
host = params.get('host')
port = params.get('port', 2222)
user = params.get('login', 'eflows')
key = params.get('key')
conn_id = f"tmp_connection_{rrid}"
extra = {"private_key": key}
conn = Connection(
conn_id=conn_id,
conn_type='ssh',
description='Automatically generated Connection',
host=host,
login=user,
port=port,
extra=extra
)
session = settings.Session()
session.add(conn)
session.commit()
print(f"Connection {conn_id} created")
return conn_id
def get_connection(conn_id):
if conn_id.startswith('vault'):
vault_hook = VaultHook(vault_conn_id='my_vault')
con = vault_hook.get_secret(secret_path=f"ssh-credentials/{conn_id[6:]}")
print(f"Got some values from vault {list(con.keys())}")
# for now SSH is hardcoded
host = con.get('host', 'bsc')
port = int(con.get('port', 22))
hook = SSHHook(remote_host=host, port=port, username=con['userName'])
#key in vault should be in form of formated string:
#-----BEGIN OPENSSH PRIVATE KEY-----
#b3BlbnNzaC1rZXktdjEAAAAA
#....
hook.pkey = hook._pkey_from_private_key(private_key=con['privateKey'])
return hook
# otherwise use previously created temp connection
return SSHHook(ssh_conn_id=conn_id)
def setup(**kwargs):
params = kwargs['params']
print(f"Setting up the connection", params )
if 'vault_id' in params:
print('Retrieving connection details from vault')
return f"vault_{params['vault_id']}"
# otherwise use creds provided in request
return create_temp_connection(rrid = kwargs['run_id'], params=params)
def remove(conn_id):
if conn_id.startswith('vault'):
return
print(f"Removing conneciton {conn_id}")
session = settings.Session()
for con in session.query(Connection).all():
print(con)
session.query(Connection).filter(Connection.conn_id == conn_id).delete()
session.commit()
def get_conn_id(**kwargs):
ti = kwargs['ti']
conn_id = ti.xcom_pull(key='return_value', task_ids='setup_connection')
return conn_id
@dag(default_args=def_args, schedule_interval=None, start_date=days_ago(2), tags=['example'])
def conn_decorator():
@task()
def doing_nothing(a_id, **kwargs):
conn_id = get_conn_id(**kwargs)
print(f"Just doing nothing with {conn_id}")
print(f"I also got ", a_id)
def doing_nothing(conn_id):
print(f"Using connection {conn_id}")
ssh_hook = get_connection(conn_id=conn_id)
with ssh_hook.get_conn() as ssh_client:
......@@ -115,8 +36,9 @@ def conn_decorator():
conn_id = PythonOperator(python_callable=setup, task_id='setup_connection')
# another way of mixing taskflow and classical api:
a_id = conn_id.output['return_value']
dno = doing_nothing(a_id=a_id)
en = PythonOperator(python_callable=remove, op_kwargs={'conn_id': dno}, task_id='cleanup')
dno = doing_nothing(conn_id=a_id)
en = PythonOperator(python_callable=remove, op_kwargs={
'conn_id': dno}, task_id='cleanup')
conn_id >> dno >> en
......
from airflow import settings
from airflow.providers.ssh.hooks.ssh import SSHHook
from airflow.models.connection import Connection
from airflow.providers.hashicorp.hooks.vault import VaultHook
def create_temp_connection(rrid, params):
host = params.get('host')
port = params.get('port', 2222)
user = params.get('login', 'eflows')
key = params.get('key')
conn_id = f"tmp_connection_{rrid}"
extra = {"private_key": key}
conn = Connection(
conn_id=conn_id,
conn_type='ssh',
description='Automatically generated Connection',
host=host,
login=user,
port=port,
extra=extra
)
session = settings.Session()
session.add(conn)
session.commit()
print(f"Connection {conn_id} created")
return conn_id
def get_connection(conn_id):
if conn_id.startswith('vault'):
vault_hook = VaultHook(vault_conn_id='my_vault')
con = vault_hook.get_secret(
secret_path=f"ssh-credentials/{conn_id[6:]}")
print(f"Got some values from vault {list(con.keys())}")
# for now SSH is hardcoded
host = con.get('host', 'bsc')
port = int(con.get('port', 22))
hook = SSHHook(remote_host=host, port=port, username=con['userName'])
# key in vault should be in form of formated string:
# -----BEGIN OPENSSH PRIVATE KEY-----
# b3BlbnNzaC1rZXktdjEAAAAA
# ....
hook.pkey = hook._pkey_from_private_key(private_key=con['privateKey'])
return hook
# otherwise use previously created temp connection
return SSHHook(ssh_conn_id=conn_id)
def setup(**kwargs):
params = kwargs['params']
print("Setting up the connection")
if 'vault_id' in params:
print('Retrieving connection details from vault')
return f"vault_{params['vault_id']}"
# otherwise use creds provided in request
return create_temp_connection(rrid=kwargs['run_id'], params=params)
def remove(conn_id):
if conn_id.startswith('vault'):
return
print(f"Removing conneciton {conn_id}")
session = settings.Session()
for con in session.query(Connection).all():
print(con)
session.query(Connection).filter(Connection.conn_id == conn_id).delete()
session.commit()
def get_conn_id(**kwargs):
ti = kwargs['ti']
conn_id = ti.xcom_pull(key='return_value', task_ids='setup_connection')
return conn_id
......@@ -2,51 +2,42 @@
import os
import tempfile
from airflow import settings
from airflow.decorators import dag, task
from airflow.models.connection import Connection
from airflow.operators.python import PythonOperator
from airflow.providers.http.hooks.http import HttpHook
from airflow.providers.ssh.hooks.ssh import SSHHook
from airflow.utils.dates import days_ago
from b2shareoperator import (add_file, create_draft_record,
get_community, submit_draft)
from b2shareoperator import (add_file, create_draft_record, get_community,
submit_draft)
from decors import remove, setup
default_args = {
'owner': 'airflow',
}
def create_template(hrespo):
return {
"titles": [{"title": hrespo['title']}],
"creators": [{"creator_name": hrespo['creator_name']}],
"descriptions": [
{
"description": hrespo['description'],
"description_type": "Abstract"
}
],
"community": "2d58eb08-af65-4cad-bd25-92f1a17d325b",
"community_specific": {
"90942261-4637-4ac0-97b8-12e1edb38739": {"helmholtz centre": ["Forschungszentrum Jülich"]}
},
"open_access": hrespo['open_access'] == "True"
}
@dag(default_args=default_args, schedule_interval=None, start_date=days_ago(2), tags=['example'])
def upload_example():
@task
def setup(**kwargs):
print(f"Setting up the connection")
params = kwargs['params']
rrid = kwargs['run_id']
host = params.get('host')
port = params.get('port', 2222)
user = params.get('login', 'eflows')
key = params.get('key')
conn_id = f"tmp_connection_{rrid}"
extra = {"private_key": key}
conn = Connection(
conn_id=conn_id,
conn_type='ssh',
description='Automatically generated Connection',
host=host,
login=user,
port=port,
extra=extra
)
session = settings.Session()
session.add(conn)
session.commit()
print(f"Connection {conn_id} created")
return conn_id
@task()
def load(connection_id, **kwargs):
......@@ -73,25 +64,6 @@ def upload_example():
return mappings
def create_template(hrespo):
return {
"titles" : [{"title": hrespo['title']}],
"creators" : [{"creator_name": hrespo['creator_name']}],
"descriptions" :[
{
"description": hrespo['description'],
"description_type": "Abstract"
}
],
"community" : "2d58eb08-af65-4cad-bd25-92f1a17d325b",
"community_specific" :{
"90942261-4637-4ac0-97b8-12e1edb38739": {"helmholtz centre": ["Forschungszentrum Jülich"]}
},
"open_access": hrespo['open_access']=="True"
}
@task()
def upload(files: dict, **kwargs):
connection = Connection.get_connection_from_secrets('default_b2share')
......@@ -99,7 +71,6 @@ def upload_example():
server = "https://" + connection.host
token = connection.extra_dejson['access_token']
params = kwargs['params']
mid = params['mid']
......@@ -107,7 +78,8 @@ def upload_example():
hrespo = hook.run(endpoint=f"storage_target/{mid}").json()['metadata']
print(hrespo)
template = create_template(hrespo=hrespo)
community = get_community(server=server, community_id=template['community'])
community = get_community(
server=server, community_id=template['community'])
if not community:
print("Not existing community")
return
......@@ -116,35 +88,31 @@ def upload_example():
if any(missing):
print(f"Community {cid} required field {missing} are missing. This could pose some problems")
r = create_draft_record(server=server, token=token, record=template)
print(r)
print(f"Draft record created {r['id']} --> {r['links']['self']}")
for [local, true_name] in files.items():
print(f"Uploading {local} --> {true_name}")
up = add_file(record=r, fname=local, token=token, remote=true_name)
_ = add_file(record=r, fname=local, token=token, remote=true_name)
print("Submitting record for pubication")
submitted = submit_draft(record=r, token=token)
print(f"Record created {submitted['id']}")
return submitted['id']
@task()
def remove(conn_id, uid):
print(f"Upload {uid} completed. Removing conneciton {conn_id}")
session = settings.Session()
for con in session.query(Connection).all():
print(con)
session.query(Connection).filter(Connection.conn_id == conn_id).delete()
session.commit()
conn_id = setup()
files = load(connection_id=conn_id)
setup_task = PythonOperator(
python_callable=setup, task_id='setup_connection')
a_id = setup_task.output['return_value']
files = load(connection_id=a_id)
uid = upload(files)
remove(conn_id=conn_id, uid=uid)
en = PythonOperator(python_callable=remove, op_kwargs={
'conn_id': a_id}, task_id='cleanup')
setup_task >> files >> uid >> en
dag = upload_example()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment