diff --git a/dags/taskflow.py b/dags/taskflow.py
index a410957376484a5efbdb8b73b6c47e51ce14a6d7..810dd07012c951bbedd8cd412e8da811b37a49b8 100644
--- a/dags/taskflow.py
+++ b/dags/taskflow.py
@@ -1,12 +1,12 @@
-from airflow import settings
+from airflow.operators.python import PythonOperator
 from airflow.decorators import dag, task
 from airflow.models.connection import Connection
-from airflow.providers.ssh.hooks.ssh import SSHHook
 from airflow.models import Variable
 from airflow.utils.dates import days_ago
 import os
 
+from decors import get_connection, remove, setup
 from b2shareoperator import (download_file, get_file_list, get_object_md,
                              get_objects)
 
@@ -18,35 +18,6 @@ default_args = {
 
 @dag(default_args=default_args, schedule_interval=None, start_date=days_ago(2), tags=['example'])
 def taskflow_example():
-    @task
-    def setup(**kwargs):
-        print(f"Setting up the connection")
-
-        params = kwargs['params']
-        rrid = kwargs['run_id']
-        host = params.get('host')
-        port = params.get('port', 2222)
-        user = params.get('login', 'eflows')
-        key = params.get('key')
-
-        conn_id = f"tmp_connection_{rrid}"
-        extra = {"private_key": key}
-        conn = Connection(
-            conn_id=conn_id,
-            conn_type='ssh',
-            description='Automatically generated Connection',
-            host=host,
-            login=user,
-            port=port,
-            extra=extra
-        )
-
-        session = settings.Session()
-        session.add(conn)
-        session.commit()
-        print(f"Connection {conn_id} created")
-        return conn_id
-
     @task(multiple_outputs=True)
     def extract(conn_id, **kwargs):
         connection = Connection.get_connection_from_secrets('default_b2share')
@@ -73,7 +44,7 @@ def taskflow_example():
         name_mappings = {}
         tmp_dir = Variable.get("working_dir", default_var='/tmp/')
         print(f"Local working dir is: {tmp_dir}")
-        
+
         for fname, url in flist.items():
             print(f"Processing: {fname} --> {url}")
             tmpname = download_file(url=url, target_dir=tmp_dir)
@@ -85,34 +56,33 @@ def taskflow_example():
         print(f"Total files downloaded: {len(files)}")
         params = kwargs['params']
         target = params.get('target', '/tmp/')
-        
+
         print(f"Using {connection_id} connection")
-
-        ssh_hook = SSHHook(ssh_conn_id=connection_id)
+        ssh_hook = get_connection(conn_id=connection_id)
+
         with ssh_hook.get_conn() as ssh_client:
             sftp_client = ssh_client.open_sftp()
             for [truename, local] in files.items():
-                print(f"Copying {local} --> {connection_id}:{os.path.join(target, truename)}")
+                print(
+                    f"Copying {local} --> {connection_id}:{os.path.join(target, truename)}")
                 sftp_client.put(local, os.path.join(target, truename))
                 # or separate cleanup task?
                 os.unlink(local)
 
         return connection_id
 
-    @task()
-    def remove(conn_id):
-        print(f"Removing conneciton {conn_id}")
-        session = settings.Session()
-        for con in session.query(Connection).all():
-            print(con)
-
-        session.query(Connection).filter(Connection.conn_id == conn_id).delete()
-        session.commit()
-
-    conn_id = setup()
-    data = extract(conn_id)
-    files = transform(data)
-    ucid = load(connection_id = conn_id, files=files)
-    remove(conn_id=ucid)
+    conn_id = PythonOperator(python_callable=setup, task_id='setup_connection')
+    # another way of mixing taskflow and classical api:
+    a_id = conn_id.output['return_value']
+
+    data = extract(conn_id=a_id)
+    files = transform(flist=data)
+    ucid = load(connection_id=a_id, files=files)
+
+    en = PythonOperator(python_callable=remove, op_kwargs={
+                        'conn_id': a_id}, task_id='cleanup')
+
+    conn_id >> data >> files >> ucid >> en
+
 
 dag = taskflow_example()