diff --git a/dags/datacat_connection_sync.py b/dags/datacat_connection_sync.py new file mode 100644 index 0000000000000000000000000000000000000000..8a8f096dc29d94d333aef39b9c61432698f7fe44 --- /dev/null +++ b/dags/datacat_connection_sync.py @@ -0,0 +1,119 @@ + +from typing import Dict +from airflow.decorators import dag, task +from airflow.models.connection import Connection +from airflow.utils.dates import days_ago +from airflow import settings +import logging +from sqlalchemy.orm.session import Session as SASession +from datacat_integration.secrets import DataCatConnectionWithSecrets + +from datacat_integration.hooks import DataCatalogHook + +default_args = { + 'owner': 'airflow', +} + +connections_type = "airflow_connections" +substring_start = len(connections_type) + 1 +substring_end = substring_start + 36 # length of a UUID4 + +log = logging.getLogger(__name__) + +def get_conn_name(datacat_type: str, oid: str): + return "{}/{}-connection".format(datacat_type, oid) + +def get_normal_or_secret_property(key: str, props: Dict[str,str], secrets: Dict[str, str], default_value = None): + return props.get(key, secrets.get(key, default_value)) + + +def get_connection(hook: DataCatalogHook, datacat_type: str, oid: str): + conn_id = get_conn_name(datacat_type, oid) + secrets_connection = DataCatConnectionWithSecrets(hook.connection.url, hook.connection.user, hook.connection._password) + datacat_entry: Dict[str,str] = secrets_connection.get_object(datacat_type, oid)['metadata'] + datacat_entry_secrets = secrets_connection.get_all_secret_key_value(datacat_type, oid) + extra={} + predefined_keys = ['conn_type', 'description', 'host', 'login', 'password', 'schema', 'port'] + # build extra from non-predefined keys + for key in datacat_entry: + if key not in predefined_keys: + extra[key] = datacat_entry[key] + + for key in datacat_entry_secrets: + if key not in predefined_keys: + extra[key] = datacat_entry_secrets[key] + + + return Connection( + conn_id=conn_id, + conn_type=get_normal_or_secret_property('conn_type', datacat_entry, datacat_entry_secrets), + description=get_normal_or_secret_property('description', datacat_entry, datacat_entry_secrets, 'Automatically generated Connection from the datacatalog object {}/{}'.format(connections_type, oid)), + host=get_normal_or_secret_property('host', datacat_entry, datacat_entry_secrets), + login=get_normal_or_secret_property('login', datacat_entry, datacat_entry_secrets), + password=get_normal_or_secret_property('password', datacat_entry, datacat_entry_secrets), + schema=get_normal_or_secret_property('schema', datacat_entry, datacat_entry_secrets), + port=int(get_normal_or_secret_property('port', datacat_entry, datacat_entry_secrets)), + extra=extra + ) + + +@dag(default_args=default_args, schedule_interval='@hourly', start_date=days_ago(1), tags=['dls-service-dag']) +def sync_connections(): + + @task + def list_catalog_connections(**kwargs): + hook = DataCatalogHook("datacatalog") + objects = hook.list_type(connections_type) + oid_list = [element[1] for element in objects] + return oid_list + + @task + def list_airflow_connections(**kwargs): + session : SASession = settings.Session() + conns = session.query(Connection).filter(Connection.conn_id.like("{}/%-connection".format(connections_type))) + oid_list = [conn.conn_id[substring_start:substring_end] for conn in conns] + return oid_list + + @task + def get_add_list(catalog_connections, airflow_connections, **kwargs): + return list(set(catalog_connections).difference(airflow_connections)) + + @task + def get_remove_list(catalog_connections, airflow_connections, **kwargs): + return list(set(airflow_connections).difference(catalog_connections)) + + @task + def remove_connections(oid_list, **kwargs): + log.info("Going to remove from conections: " + ','.join(oid_list)) + session : SASession = settings.Session() + for oid in oid_list: + session.query(Connection).filter(Connection.conn_id == get_conn_name(connections_type, oid)).delete() + session.commit() + + @task + def add_connections(oid_list, **kwargs): + log.info("Going to add to conections: " + ','.join(oid_list)) + hook = DataCatalogHook("datacatalog") + connections = [] + for oid in oid_list: + connections.append(get_connection(hook, connections_type, oid)) + + session = settings.Session() + # no check for existsnce necessary, since it is handled by get_add_list() + for conn in connections: + session.add(conn) + + session.commit() + + cat_conn = list_catalog_connections() + air_conn = list_airflow_connections() + + add_list = get_add_list(cat_conn, air_conn) + remove_list = get_remove_list(cat_conn, air_conn) + + add_connections(add_list) + + remove_connections(remove_list) + + +dag = sync_connections()