diff --git a/dags/conn_deco.py b/dags/conn_deco.py index b773671c6e2ad75437e8b9257378d099832e8b9e..1e88382f14c44b5b86c0a058a0b9724bd6ef2b8f 100644 --- a/dags/conn_deco.py +++ b/dags/conn_deco.py @@ -5,6 +5,7 @@ from airflow.decorators import dag, task from airflow.providers.ssh.hooks.ssh import SSHHook from airflow.models.connection import Connection from airflow.utils.dates import days_ago +from airflow.providers.hashicorp.hooks.vault import VaultHook def_args = { 'owner': 'airflow', @@ -13,17 +14,8 @@ def_args = { 'email_on_retry': False, 'retries': 1, 'retry_delay': timedelta(minutes=5) - } - -@dag(default_args=def_args, schedule_interval=None, start_date=days_ago(2), tags=['example']) -def conn_decorator(): - @task - def setup(**kwargs): - print(f"Setting up the connection") - - params = kwargs['params'] - rrid = kwargs['run_id'] +def create_temp_connection(rrid, params): host = params.get('host') port = params.get('port', 2222) user = params.get('login', 'eflows') @@ -47,21 +39,61 @@ def conn_decorator(): print(f"Connection {conn_id} created") return conn_id +def get_connection(conn_id): + if conn_id.startswith('vault'): + vault_hook = VaultHook(vault_conn_id='my_vault') + con = vault_hook.get_secret(secret_path=f"connections/{conn_id[6:]}") + print(f"Got some values from vault {list(con.keys())}") + + # for now SSH is hardcoded + hook = SSHHook(remote_host=con['host'], port=int(con['port']), username=con['login']) + #key in vault should be in form of formated string: + #-----BEGIN OPENSSH PRIVATE KEY----- + #b3BlbnNzaC1rZXktdjEAAAAA + #.... + hook.pkey = hook._pkey_from_private_key(private_key=con['key']) + return hook + + # otherwise use previously created temp connection + return SSHHook(ssh_conn_id=conn_id) + + + +@dag(default_args=def_args, schedule_interval=None, start_date=days_ago(2), tags=['example']) +def conn_decorator(): + @task + def setup(**kwargs): + print(f"Setting up the connection") + + params = kwargs['params'] + if 'vault_id' in params: + print('Retrieving connection details from vault') + return f"vault_{params['vault_id']}" + + # otherwise use creds provided in request + return create_temp_connection(rrid = kwargs['run_id'], parms=params) + + @task() def doing_nothing(conn_id, **kwargs): print(f"Just doing nothing with {conn_id}") - params = kwargs['params'] - print(f"This task recieved following kwargs: {params}") + - ssh_hook = SSHHook(ssh_conn_id=conn_id) + ssh_hook = get_connection(conn_id=conn_id) with ssh_hook.get_conn() as ssh_client: sftp_client = ssh_client.open_sftp() print("Connected") + lst = sftp_client.listdir(path='/tmp/') + for f in lst: + print(f) return conn_id @task() def remove(conn_id): + if conn_id.startswith('vault'): + return + print(f"Removing conneciton {conn_id}") session = settings.Session() for con in session.query(Connection).all():