Skip to content
Snippets Groups Projects
Select Git revision
  • main
  • main-patch-8ae7
  • second
3 results

utility.py

Blame
  • hooks.py 1.83 KiB
    from typing import Any
    from airflow.providers.http.hooks.http import HttpHook
    from airflow import settings
    from airflow.models.connection import Connection
    from requests import Session
    from typing import Optional, Dict
    
    from datacat_integration.connection import DataCatConnection, DataCatalogEntry, get_connection_from_entry
    
    from datacat_integration.auth import BearerAuth
    
    class DataCatalogHook(HttpHook):
        """Airflow HTTP hook for a DataCat catalog service.

        Reads the catalog URL and login/password from an Airflow HTTP connection
        (``http_conn_id``), wraps them in a :class:`DataCatConnection`, and
        delegates catalog operations (get/create/list entries) to it.
        """

        # Catalog client; built in __init__ and rebuilt on every get_conn() call.
        connection: Optional[DataCatConnection] = None

        def __init__(self, http_conn_id: str = 'datacatalog') -> None:
            """Create the hook and eagerly build ``self.connection``.

            :param http_conn_id: Airflow connection id holding the catalog URL and
                the login/password used to authenticate against it.
            """
            super().__init__(method="GET", http_conn_id=http_conn_id)
            # Call HttpHook.get_conn directly (not the override below) only to
            # resolve base_url and the credentials. The session is not used for
            # any request, so release it right away.
            session = super().get_conn()
            try:
                self.connection = self._datacat_connection_from(session)
            finally:
                session.close()

        def _datacat_connection_from(self, session: Session) -> DataCatConnection:
            """Build a :class:`DataCatConnection` from *session*'s auth credentials.

            Expects ``session.auth`` to expose ``username``/``password`` (presumably
            the basic auth HttpHook attaches when the Airflow connection has a
            login; confirm for connections without one). Also expects
            ``self.base_url`` to have been set by ``HttpHook.get_conn``.
            """
            return DataCatConnection(
                catalog_url=self.base_url,
                username=session.auth.username,
                password=session.auth.password,
            )

        def get_conn(self, headers: Optional[Dict[Any, Any]] = None) -> Session:
            """Return a requests session authenticated with a catalog bearer token.

            Refreshes ``self.connection`` from the Airflow connection. It then
            replaces the session's auth with
            ``BearerAuth(self.connection.get_token())``.

            :param headers: Extra headers forwarded to ``HttpHook.get_conn``.
            :return: The session with bearer auth installed.
            """
            session = super().get_conn(headers=headers)
            self.connection = self._datacat_connection_from(session)
            session.auth = BearerAuth(self.connection.get_token())
            return session

        def get_entry(self, datacat_type: str, oid: str):
            """Fetch an entry via ``self.connection.get_entry(datacat_type, oid)``."""
            return self.connection.get_entry(datacat_type, oid)

        def create_entry(self, datacat_type: str, entry: DataCatalogEntry):
            """Create *entry* via ``self.connection.create_entry`` and return its result."""
            return self.connection.create_entry(datacat_type, entry)

        def list_type(self, datacat_type: str):
            """Return ``self.connection.list_type(datacat_type)``."""
            return self.connection.list_type(datacat_type)

        def create_get_entry_connection(self, datacat_type: str, oid: str):
            """Register a catalog entry as an Airflow ``Connection`` if not present.

            Fetches the entry and converts it with ``get_connection_from_entry``.
            It is added to the Airflow metadata DB only when no connection with
            the same ``conn_id`` exists yet.

            :param datacat_type: Catalog type of the entry.
            :param oid: Id of the entry within that type.
            """
            entry = self.get_entry(datacat_type, oid)
            conn = get_connection_from_entry(entry, datacat_type, oid)

            session = settings.Session()
            try:
                existing = (
                    session.query(Connection)
                    .filter(Connection.conn_id == conn.conn_id)
                    .first()
                )
                # Test for a missing row explicitly instead of comparing
                # str(existing) to the id. Stringifying an ORM object depends on
                # Connection's __str__/__repr__, and a missing row would be
                # compared as the literal text 'None'.
                if existing is None:
                    session.add(conn)
                    session.commit()
            except Exception:
                session.rollback()
                raise
            finally:
                # Always close the session, even if the query or commit fails.
                session.close()