Skip to content
Snippets Groups Projects
Commit b6b81e31 authored by Jedrzej Rybicki's avatar Jedrzej Rybicki
Browse files

wip: towards more flexible conn management

parent 506e73c8
No related branches found
No related tags found
No related merge requests found
...@@ -6,6 +6,7 @@ from airflow.providers.ssh.hooks.ssh import SSHHook ...@@ -6,6 +6,7 @@ from airflow.providers.ssh.hooks.ssh import SSHHook
from airflow.models.connection import Connection from airflow.models.connection import Connection
from airflow.utils.dates import days_ago from airflow.utils.dates import days_ago
from airflow.providers.hashicorp.hooks.vault import VaultHook from airflow.providers.hashicorp.hooks.vault import VaultHook
from airflow.operators.python import PythonOperator
def_args = { def_args = {
'owner': 'airflow', 'owner': 'airflow',
...@@ -15,6 +16,8 @@ def_args = { ...@@ -15,6 +16,8 @@ def_args = {
'retries': 1, 'retries': 1,
'retry_delay': timedelta(minutes=5) 'retry_delay': timedelta(minutes=5)
} }
def create_temp_connection(rrid, params): def create_temp_connection(rrid, params):
host = params.get('host') host = params.get('host')
port = params.get('port', 2222) port = params.get('port', 2222)
...@@ -42,44 +45,63 @@ def create_temp_connection(rrid, params): ...@@ -42,44 +45,63 @@ def create_temp_connection(rrid, params):
def get_connection(conn_id): def get_connection(conn_id):
if conn_id.startswith('vault'): if conn_id.startswith('vault'):
vault_hook = VaultHook(vault_conn_id='my_vault') vault_hook = VaultHook(vault_conn_id='my_vault')
con = vault_hook.get_secret(secret_path=f"connections/{conn_id[6:]}") con = vault_hook.get_secret(secret_path=f"ssh-credentials/{conn_id[6:]}")
print(f"Got some values from vault {list(con.keys())}") print(f"Got some values from vault {list(con.keys())}")
# for now SSH is hardcoded # for now SSH is hardcoded
hook = SSHHook(remote_host=con['host'], port=int(con['port']), username=con['login']) host = con.get('host', 'bsc')
port = int(con.get('port', 22))
hook = SSHHook(remote_host=host, port=port, username=con['userName'])
#key in vault should be in form of formated string: #key in vault should be in form of formated string:
#-----BEGIN OPENSSH PRIVATE KEY----- #-----BEGIN OPENSSH PRIVATE KEY-----
#b3BlbnNzaC1rZXktdjEAAAAA #b3BlbnNzaC1rZXktdjEAAAAA
#.... #....
hook.pkey = hook._pkey_from_private_key(private_key=con['key']) hook.pkey = hook._pkey_from_private_key(private_key=con['privateKey'])
return hook return hook
# otherwise use previously created temp connection # otherwise use previously created temp connection
return SSHHook(ssh_conn_id=conn_id) return SSHHook(ssh_conn_id=conn_id)
def setup(**kwargs):
@dag(default_args=def_args, schedule_interval=None, start_date=days_ago(2), tags=['example']) params = kwargs['params']
def conn_decorator(): print(f"Setting up the connection", params )
@task
def setup(**kwargs):
print(f"Setting up the connection")
params = kwargs['params']
if 'vault_id' in params: if 'vault_id' in params:
print('Retrieving connection details from vault') print('Retrieving connection details from vault')
return f"vault_{params['vault_id']}" return f"vault_{params['vault_id']}"
# otherwise use creds provided in request # otherwise use creds provided in request
return create_temp_connection(rrid = kwargs['run_id'], parms=params) return create_temp_connection(rrid = kwargs['run_id'], params=params)
def remove(conn_id):
if conn_id.startswith('vault'):
return
print(f"Removing conneciton {conn_id}")
session = settings.Session()
for con in session.query(Connection).all():
print(con)
session.query(Connection).filter(Connection.conn_id == conn_id).delete()
session.commit()
def get_conn_id(**kwargs):
ti = kwargs['ti']
conn_id = ti.xcom_pull(key='return_value', task_ids='setup_connection')
return conn_id
@dag(default_args=def_args, schedule_interval=None, start_date=days_ago(2), tags=['example'])
def conn_decorator():
@task() @task()
def doing_nothing(conn_id, **kwargs): def doing_nothing(**kwargs):
conn_id = get_conn_id(**kwargs)
print(f"Just doing nothing with {conn_id}") print(f"Just doing nothing with {conn_id}")
ssh_hook = get_connection(conn_id=conn_id)
ssh_hook = get_connection(conn_id=conn_id)
with ssh_hook.get_conn() as ssh_client: with ssh_hook.get_conn() as ssh_client:
sftp_client = ssh_client.open_sftp() sftp_client = ssh_client.open_sftp()
print("Connected") print("Connected")
...@@ -89,22 +111,11 @@ def conn_decorator(): ...@@ -89,22 +111,11 @@ def conn_decorator():
return conn_id return conn_id
@task() conn_id = PythonOperator(python_callable=setup, task_id='setup_connection')
def remove(conn_id): dno = doing_nothing()
if conn_id.startswith('vault'): en = PythonOperator(python_callable=remove, op_kwargs={'conn_id': dno}, task_id='cleanup')
return
print(f"Removing conneciton {conn_id}")
session = settings.Session()
for con in session.query(Connection).all():
print(con)
session.query(Connection).filter(Connection.conn_id == conn_id).delete()
session.commit()
conn_id = setup() conn_id >> dno >> en
conn_id = doing_nothing(conn_id=conn_id)
remove(conn_id)
dag = conn_decorator() dag = conn_decorator()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment