Skip to content
Snippets Groups Projects
Commit fe17a24d authored by Jedrzej Rybicki
Browse files

taskflow with connection decorator

parent 34c4eef7
No related branches found
No related tags found
No related merge requests found
from airflow import settings
from airflow.decorators import dag, task
from airflow.models.connection import Connection
from airflow.providers.ssh.hooks.ssh import SSHHook
......@@ -16,8 +17,38 @@ default_args = {
@dag(default_args=default_args, schedule_interval=None, start_date=days_ago(2), tags=['example'])
def taskflow_example():
@task
def setup(**kwargs):
    """Create a temporary SSH connection for the current DAG run.

    Reads the connection settings from the DAG run ``params`` found in the
    task context and stores a new ``Connection`` row through
    ``settings.Session``.

    Recognised ``params`` keys:
        host:  target host (no default).
        port:  SSH port, defaults to 2222.
        login: user name, defaults to ``'eflows'``.
        key:   private key, stored under ``private_key`` in the extras.

    Returns:
        The id of the created connection, ``tmp_connection_<run_id>``.

    Raises:
        KeyError: if ``params`` or ``run_id`` is missing from the context.
    """
    # Local import so the block does not depend on the file's import section.
    import json

    print("Setting up the connection")
    params = kwargs['params']
    rrid = kwargs['run_id']
    host = params.get('host')
    port = params.get('port', 2222)
    user = params.get('login', 'eflows')
    key = params.get('key')

    # The run id makes the connection id unique per DAG run.
    conn_id = f"tmp_connection_{rrid}"
    # Serialise the extras explicitly: the extra field is persisted as a
    # string, and not every Airflow release accepts a raw dict here.
    extra = json.dumps({"private_key": key})
    conn = Connection(
        conn_id=conn_id,
        conn_type='ssh',
        description='Automatically generated Connection',
        host=host,
        login=user,
        port=port,
        extra=extra,
    )

    session = settings.Session()
    try:
        session.add(conn)
        session.commit()
    except Exception:
        # Do not leave a half-finished transaction behind on failure.
        session.rollback()
        raise
    finally:
        # Always hand the DB connection back, success or failure.
        session.close()

    print(f"Connection {conn_id} created")
    return conn_id
@task(multiple_outputs=True)
def extract(**kwargs):
def extract(conn_id, **kwargs):
connection = Connection.get_connection_from_secrets('default_b2share')
server = connection.get_uri()
print(f"Rereiving data from {server}")
......@@ -50,11 +81,12 @@ def taskflow_example():
return name_mappings
@task()
def load(files: dict, **kwargs):
def load(connection_id, files: dict, **kwargs):
print(f"Total files downloaded: {len(files)}")
params = kwargs['params']
target = params.get('target', '/tmp/')
connection_id = params.get('connection', 'default_ssh')
print(f"Using {connection_id} connection")
ssh_hook = SSHHook(ssh_conn_id=connection_id)
with ssh_hook.get_conn() as ssh_client:
......@@ -65,9 +97,22 @@ def taskflow_example():
# or separate cleanup task?
os.unlink(local)
data = extract()
files = transform(data)
load(files)
return connection_id
@task()
def remove(conn_id):
    """Delete the connection with id ``conn_id`` through ``settings.Session``.

    Before deleting, every stored connection is printed to the task log
    (debug output kept from the original implementation).

    Args:
        conn_id: id of the connection to delete.
    """
    print(f"Removing conneciton {conn_id}")
    session = settings.Session()
    try:
        # Debug aid: list what is currently stored before deleting.
        for con in session.query(Connection).all():
            print(con)

        session.query(Connection).filter(Connection.conn_id == conn_id).delete()
        session.commit()
    except Exception:
        # Undo the partial transaction so the session is left clean.
        session.rollback()
        raise
    finally:
        # Always hand the DB connection back, success or failure.
        session.close()
# Task wiring: each call's result is handed to the next call, so the steps
# are chained setup -> extract -> transform -> load -> remove.
# setup() returns the id of the temporary connection it created.
conn_id = setup()
data = extract(conn_id)
# NOTE(review): transform is defined outside the visible part of the diff.
files = transform(data)
# NOTE(review): load appears to return the connection id it was given; feeding
# that result (rather than conn_id directly) into remove presumably makes the
# cleanup wait for the upload to finish -- confirm.
ucid = load(connection_id = conn_id, files=files)
remove(conn_id=ucid)
# Call the decorated DAG function and keep the result in a module-level name.
dag = taskflow_example()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment