import json
import os

from airflow import settings
from airflow.decorators import dag, task
from airflow.models import Variable
from airflow.models.connection import Connection
from airflow.providers.ssh.hooks.ssh import SSHHook
from airflow.utils.dates import days_ago

from b2shareoperator import (download_file, get_file_list, get_object_md,
                             get_objects)

default_args = {
    'owner': 'airflow',
}


@dag(default_args=default_args, schedule_interval=None, start_date=days_ago(2), tags=['example'])
def taskflow_example():
    """Download the files of one B2SHARE object and copy them to a host over SFTP.

    The run is driven by DAG params: ``host``, ``port``, ``login`` and ``key``
    describe the SSH target, ``oid`` selects the object to fetch, and ``target``
    is the remote directory. A temporary SSH connection is created for the run
    and deleted again once the upload has finished.
    """

    @task
    def setup(**kwargs):
        """Create a per-run SSH connection from the DAG params and return its id."""
        print("Setting up the connection")

        params = kwargs['params']
        rrid = kwargs['run_id']
        host = params.get('host')
        port = params.get('port', 2222)
        user = params.get('login', 'eflows')
        key = params.get('key')

        # The run id keeps the connection id unique per DAG run.
        conn_id = f"tmp_connection_{rrid}"
        # Connection.extra is persisted as a JSON string; serialize explicitly
        # so this also works on Airflow versions that do not coerce dicts.
        extra = json.dumps({"private_key": key})
        conn = Connection(
            conn_id=conn_id,
            conn_type='ssh',
            description='Automatically generated Connection',
            host=host,
            login=user,
            port=port,
            extra=extra
        )

        session = settings.Session()
        session.add(conn)
        session.commit()
        print(f"Connection {conn_id} created")
        return conn_id

    @task(multiple_outputs=True)
    def extract(conn_id, **kwargs):
        """Return the file list of the object named by the ``oid`` param.

        ``conn_id`` is not used here; accepting it makes this task run after
        ``setup``.

        Raises:
            ValueError: if ``oid`` is missing from the DAG params. The objects
                available on the server are printed first to help the caller
                pick one.
        """
        connection = Connection.get_connection_from_secrets('default_b2share')
        server = connection.get_uri()
        print(f"Rereiving data from {server}")

        params = kwargs['params']
        if 'oid' not in params:  # {"oid":"b38609df2b334ea296ea1857e568dbea"}
            print("Missing object id in pipeline parameters")
            lst = get_objects(server=server)
            flist = {o['id']: [f['key'] for f in o['files']] for o in lst}
            print(f"Objects on server: {flist}")
            # A task has no exit code: raising is what marks it as failed.
            # (Returning a non-dict from a multiple_outputs task only failed
            # indirectly, with a misleading type error.)
            raise ValueError("Missing object id ('oid') in pipeline parameters")

        oid = params['oid']

        obj = get_object_md(server=server, oid=oid)
        print(f"Retrieved object {oid}: {obj}")
        flist = get_file_list(obj)
        return flist

    @task(multiple_outputs=True)
    def transform(flist: dict):
        """Download every file in ``flist`` and map its name to the local copy.

        ``flist`` maps file name to download URL. Files are stored under the
        ``working_dir`` Airflow Variable (default ``/tmp/``).
        """
        name_mappings = {}
        tmp_dir = Variable.get("working_dir", default_var='/tmp/')
        print(f"Local working dir is: {tmp_dir}")

        for fname, url in flist.items():
            print(f"Processing: {fname} --> {url}")
            tmpname = download_file(url=url, target_dir=tmp_dir)
            name_mappings[fname] = tmpname
        return name_mappings

    @task()
    def load(connection_id, files: dict, **kwargs):
        """Upload the downloaded files over SFTP and delete the local copies.

        ``files`` maps the original file name to the local path. Each file is
        written to the ``target`` param directory (default ``/tmp/``) under its
        original name. Returns ``connection_id`` so the cleanup task runs
        after the upload.
        """
        print(f"Total files downloaded: {len(files)}")
        params = kwargs['params']
        target = params.get('target', '/tmp/')

        print(f"Using {connection_id} connection")
        ssh_hook = SSHHook(ssh_conn_id=connection_id)

        with ssh_hook.get_conn() as ssh_client:
            sftp_client = ssh_client.open_sftp()
            try:
                for truename, local in files.items():
                    remote = os.path.join(target, truename)
                    print(f"Copying {local} --> {connection_id}:{remote}")
                    sftp_client.put(local, remote)
                    # or separate cleanup task?
                    os.unlink(local)
            finally:
                # Release the SFTP channel even if a transfer fails.
                sftp_client.close()

        return connection_id

    @task()
    def remove(conn_id):
        """Delete the temporary connection ``conn_id`` from the Airflow database."""
        print(f"Removing conneciton {conn_id}")
        session = settings.Session()
        # Debug aid: list every stored connection before deleting ours.
        for con in session.query(Connection).all():
            print(con)

        session.query(Connection).filter(Connection.conn_id == conn_id).delete()
        session.commit()

    # NOTE(review): remove() only runs when load() succeeds, so a failed run
    # appears to leave the temporary connection (with its key) behind - confirm
    # whether cleanup should use a different trigger rule.
    conn_id = setup()
    data = extract(conn_id)
    files = transform(data)
    ucid = load(connection_id=conn_id, files=files)
    remove(conn_id=ucid)


dag = taskflow_example()