diff --git a/src/datacat_integration/connection.py b/src/datacat_integration/connection.py index 0b50e3cbe24a08ccd447eb3be39a122879da28e4..f936305af1120f30a2b401a1dd1c6b3542581eb8 100644 --- a/src/datacat_integration/connection.py +++ b/src/datacat_integration/connection.py @@ -11,17 +11,20 @@ def get_connection_from_entry(data: Dict[str, Any], datacat_type: str, oid: str) """returns an airflow connection from the data provided in the datacat entry.""" # delay import to prevent circular dependencies during config startup from airflow.models.connection import Connection - conn_type = data['metadata'].get('conn_type') - host = data['metadata'].get('host') - port = data['metadata'].get('port') - schema = data['metadata'].get('schema') + metadata = data['metadata'] + if isinstance(metadata, str): + raise TypeError("Datacat entry metadata is a string instead of a dict.") + conn_type = metadata.get('conn_type') + host = metadata.get('host') + port = metadata.get('port') + schema = metadata.get('schema') conn_id = f"{datacat_type}/{oid}-connection" # set all remaining metadata as extra extra = {} - for key in data['metadata'].keys(): + for key in metadata.keys(): if key in ['conn_type', 'host', 'port', 'schema']: continue - extra[key] = data['metadata'][key] + extra[key] = metadata[key] return Connection( conn_id=conn_id, diff --git a/src/datacat_integration/secrets.py b/src/datacat_integration/secrets.py index b409c0f2e2962462e03924fd00af981b71732401..3d20e5709f40730a0a86e39410c2149dfbd474df 100644 --- a/src/datacat_integration/secrets.py +++ b/src/datacat_integration/secrets.py @@ -11,9 +11,6 @@ connection_backend_type = "airflow_connections" log = logging.getLogger(__name__) -class Empty: - pass - def get_connection_with_secrets_from_entry(data: Dict[str, Any], secrets: Dict[str, str] , datacat_type: str, oid: str): """returns an aiflow connection from the data provided in the datacat entry and the secrets.""" conn = get_connection_from_entry(data, datacat_type, oid) @@ -88,7 +85,7 @@ class DatacatSecretsBackend(BaseSecretsBackend): log.debug(f"Get connection: {conn_id}") secrets_conn = DataCatConnectionWithSecrets(self.url, self.user, self.password) - data: Dict[str,str] = DataCatalogEntry.from_json(secrets_conn.get_entry(connection_backend_type, conn_id)).__dict__ + data: Dict[str,str] = json.loads(secrets_conn.get_entry(connection_backend_type, conn_id)) secrets: Dict[str,str] = json.loads(secrets_conn.get_all_secret_key_value(connection_backend_type, conn_id)) conn = get_connection_with_secrets_from_entry(data, secrets, connection_backend_type, conn_id) return conn \ No newline at end of file diff --git a/tests/test_connection.py b/tests/test_connection.py index 4b23a15d086303f6491e44fa57bc83c106c6edaa..9e878bb8317c1f461d701b363635628f5464cd95 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1,8 +1,34 @@ from unittest import TestCase -import os, random +import os, random, json from dotenv import load_dotenv -from datacat_integration.connection import DataCatalogEntry, DataCatConnection +from datacat_integration.connection import DataCatalogEntry, DataCatConnection, get_connection_from_entry + + +class GetConnectionTest(TestCase): + def setUp(self): + pass + + def test_get_connection(self): + data = {"name" : "foo", "url" : "bar", "metadata" : {"a" : "b", "c" : "d", "conn_type" : "http", "port" : "443", "host" : "test.com", "schema" : "https"}} + datacat_type = "airflow_connections" + oid = "860355e9-975f-4253-9421-1815e20c879b" + + conn = get_connection_from_entry(data, datacat_type, oid) + + self.assertEqual(conn.conn_id, f"{datacat_type}/{oid}-connection") + self.assertEqual(conn.conn_type, "http") + self.assertEqual(str(conn.port), "443") + self.assertEqual(json.loads(conn.get_extra())["a"], "b") + + def test_get_wrong_metadata(self): + # metadata is a json string instead of a dict + data = {"name" : "foo", "url" : "bar", "metadata" : '{"a" : "b", "c" : "d", "conn_type" : "http", "port" : "443", "host" : "test.com", "schema" : "https"}'} + datacat_type = "airflow_connections" + oid = "860355e9-975f-4253-9421-1815e20c879b" + + self.assertRaises(TypeError, get_connection_from_entry, data, datacat_type, oid) + class EntryTest(TestCase):