Skip to content
Snippets Groups Projects
Commit fe17a24d authored by Jedrzej Rybicki's avatar Jedrzej Rybicki
Browse files

taskflow with connection decorator

parent 34c4eef7
No related branches found
No related tags found
No related merge requests found
from airflow import settings
from airflow.decorators import dag, task from airflow.decorators import dag, task
from airflow.models.connection import Connection from airflow.models.connection import Connection
from airflow.providers.ssh.hooks.ssh import SSHHook from airflow.providers.ssh.hooks.ssh import SSHHook
...@@ -16,8 +17,38 @@ default_args = { ...@@ -16,8 +17,38 @@ default_args = {
@dag(default_args=default_args, schedule_interval=None, start_date=days_ago(2), tags=['example']) @dag(default_args=default_args, schedule_interval=None, start_date=days_ago(2), tags=['example'])
def taskflow_example(): def taskflow_example():
@task
def setup(**kwargs):
    """Create a temporary SSH connection for this DAG run and store it.

    Reads the connection settings from the run's ``params`` and builds an
    Airflow ``Connection`` of type ``ssh`` whose id is derived from the
    run id, then persists it through an Airflow ORM session.

    Expected entries in ``kwargs``:
        params: mapping with ``host`` and ``key`` (no defaults, ``None``
            when absent), ``port`` (default 2222) and ``login``
            (default ``'eflows'``).
        run_id: used to build the connection id.

    Returns:
        The id of the newly created connection (``tmp_connection_<run_id>``).

    Raises:
        KeyError: if ``params`` or ``run_id`` is missing from ``kwargs``.
    """
    # Local import keeps this task self-contained; the file's import block
    # is not fully visible from here.
    import json

    print("Setting up the connection")
    params = kwargs['params']
    rrid = kwargs['run_id']
    host = params.get('host')
    port = params.get('port', 2222)
    user = params.get('login', 'eflows')
    key = params.get('key')

    conn_id = f"tmp_connection_{rrid}"
    # Connection.extra is stored as a JSON string; serialise explicitly so
    # this does not depend on the Connection constructor accepting a dict.
    extra = json.dumps({"private_key": key})
    conn = Connection(
        conn_id=conn_id,
        conn_type='ssh',
        description='Automatically generated Connection',
        host=host,
        login=user,
        port=port,
        extra=extra,
    )

    session = settings.Session()
    try:
        session.add(conn)
        session.commit()
    except Exception:
        # Leave the session clean before propagating the failure.
        session.rollback()
        raise
    finally:
        session.close()

    print(f"Connection {conn_id} created")
    return conn_id
@task(multiple_outputs=True) @task(multiple_outputs=True)
def extract(**kwargs): def extract(conn_id, **kwargs):
connection = Connection.get_connection_from_secrets('default_b2share') connection = Connection.get_connection_from_secrets('default_b2share')
server = connection.get_uri() server = connection.get_uri()
print(f"Rereiving data from {server}") print(f"Rereiving data from {server}")
...@@ -50,11 +81,12 @@ def taskflow_example(): ...@@ -50,11 +81,12 @@ def taskflow_example():
return name_mappings return name_mappings
@task() @task()
def load(files: dict, **kwargs): def load(connection_id, files: dict, **kwargs):
print(f"Total files downloaded: {len(files)}") print(f"Total files downloaded: {len(files)}")
params = kwargs['params'] params = kwargs['params']
target = params.get('target', '/tmp/') target = params.get('target', '/tmp/')
connection_id = params.get('connection', 'default_ssh')
print(f"Using {connection_id} connection")
ssh_hook = SSHHook(ssh_conn_id=connection_id) ssh_hook = SSHHook(ssh_conn_id=connection_id)
with ssh_hook.get_conn() as ssh_client: with ssh_hook.get_conn() as ssh_client:
...@@ -65,9 +97,22 @@ def taskflow_example(): ...@@ -65,9 +97,22 @@ def taskflow_example():
# or separate cleanup task? # or separate cleanup task?
os.unlink(local) os.unlink(local)
data = extract() return connection_id
files = transform(data)
load(files) @task()
def remove(conn_id):
    """Delete the stored connection(s) whose ``conn_id`` matches.

    Args:
        conn_id: identifier of the connection to delete.

    The session is rolled back if the delete or commit fails and is always
    closed afterwards, so a failure does not leave it open.
    """
    print(f"Removing conneciton {conn_id}")
    session = settings.Session()
    try:
        # Debug output: print every stored connection before deleting.
        for con in session.query(Connection).all():
            print(con)
        session.query(Connection).filter(Connection.conn_id == conn_id).delete()
        session.commit()
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()
conn_id = setup()
data = extract(conn_id)
files = transform(data)
ucid = load(connection_id = conn_id, files=files)
remove(conn_id=ucid)
dag = taskflow_example() dag = taskflow_example()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment